# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch

# Ordered (old, new) substring renames applied to every checkpoint key.
# Order matters: each rule sees the output of the previous ones (e.g. "convs"
# becomes "neck.convs" only after "backbone.vision_backbone" has already been
# rewritten to "image_encoder").
_KEY_RENAMES = (
    ("backbone.vision_backbone", "image_encoder"),
    ("mlp.fc1", "mlp.layers.0"),
    ("mlp.fc2", "mlp.layers.1"),
    ("convs", "neck.convs"),
    ("transformer.encoder", "memory_attention"),
    ("maskmem_backbone", "memory_encoder"),
    ("mlp.lin1", "mlp.layers.0"),
    ("mlp.lin2", "mlp.layers.1"),
)


def _rename_key(key: str) -> str:
    """Return *key* with every rule in ``_KEY_RENAMES`` applied, in order."""
    for old, new in _KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _converted_path(src: str) -> str:
    """Return *src* with ``_converted`` inserted before its file extension.

    "ckpt.pt" -> "ckpt_converted.pt" and "ckpt.pth" -> "ckpt_converted.pth".
    Only the extension is considered, so a ".pt" elsewhere in the path (e.g.
    in a directory name) is left untouched. The result always differs from
    *src*, so the source checkpoint is never overwritten.
    """
    root, ext = os.path.splitext(src)
    return f"{root}_converted{ext}"


def main(args: argparse.Namespace) -> None:
    """Rename the keys of a checkpoint's ``"model"`` state dict and save it.

    Loads ``args.src``, takes its ``"model"`` entry, drops every key that
    contains "teacher", rewrites the remaining keys with ``_KEY_RENAMES``, and
    saves ``{"model": state_dict}`` to ``_converted_path(args.src)``.

    Args:
        args: parsed CLI arguments; only ``args.src`` (checkpoint path) is read.
    """
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    sd = torch.load(args.src, map_location="cpu")["model"]
    # The "teacher" filter is tested against the original (pre-rename) key.
    sd = {_rename_key(k): v for k, v in sd.items() if "teacher" not in k}
    torch.save({"model": sd}, _converted_path(args.src))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True)
    args = parser.parse_args()
    main(args)