# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os

import torch
# Ordered (old, new) substring renames applied to every checkpoint key.
# Order matters: each rule is applied to the whole state dict before the
# next rule runs, exactly like a chain of successive dict rebuilds.
_KEY_RENAMES = (
    ("backbone.vision_backbone", "image_encoder"),
    ("mlp.fc1", "mlp.layers.0"),
    ("mlp.fc2", "mlp.layers.1"),
    ("convs", "neck.convs"),
    ("transformer.encoder", "memory_attention"),
    ("maskmem_backbone", "memory_encoder"),
    ("mlp.lin1", "mlp.layers.0"),
    ("mlp.lin2", "mlp.layers.1"),
)


def _converted_path(src):
    """Return ``src`` with ``_converted`` inserted before its file extension."""
    # Only the final extension is touched, so ".pt" appearing in a directory
    # name is left alone and an extension-less path still gets a new name
    # instead of resolving to the source file itself.
    root, ext = os.path.splitext(src)
    return root + "_converted" + ext


def main(args: argparse.Namespace) -> None:
    """Rename the keys of a checkpoint's state dict and save the result.

    The checkpoint at ``args.src`` is loaded on CPU and its ``"model"`` entry
    is taken as the state dict. Every key containing ``"teacher"`` is dropped,
    the substring renames in ``_KEY_RENAMES`` are applied in order, and the
    result is saved as ``{"model": state_dict}`` next to the source file,
    with ``_converted`` inserted before the file extension
    (``ckpt.pt`` -> ``ckpt_converted.pt``).

    Args:
        args: Namespace with a ``src`` attribute holding the checkpoint path.

    Raises:
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    # NOTE(review): torch.load relies on pickle unless weights-only loading is
    # in effect; only run this on checkpoints from a trusted source.
    sd = torch.load(args.src, map_location="cpu")["model"]

    sd = {k: v for k, v in sd.items() if "teacher" not in k}

    # One full pass per rule keeps key order and collision behaviour the same
    # as rebuilding the dict once per rename.
    for old, new in _KEY_RENAMES:
        sd = {k.replace(old, new): v for k, v in sd.items()}

    torch.save({"model": sd}, _converted_path(args.src))
if __name__ == "__main__":
    # Command-line entry point: parse the required --src option and hand the
    # resulting namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--src", required=True, type=str)
    main(cli.parse_args())