# EdgeTAM / convert_weights.py
# Source: Hugging Face file view, commit 72780d8 ("add demo code") by chongzhou, 1.36 kB.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from pathlib import Path

import torch
# Ordered (old, new) substring substitutions applied to every state-dict key.
# Rules run in this order, each on the output of the previous one, and
# str.replace rewrites every occurrence of `old` anywhere in the key.
# NOTE(review): "convs" is unanchored, so any key containing it gets "neck."
# inserted in front -- confirm no non-neck module uses that name.
_KEY_RENAMES = (
    ("backbone.vision_backbone", "image_encoder"),
    ("mlp.fc1", "mlp.layers.0"),
    ("mlp.fc2", "mlp.layers.1"),
    ("convs", "neck.convs"),
    ("transformer.encoder", "memory_attention"),
    ("maskmem_backbone", "memory_encoder"),
    ("mlp.lin1", "mlp.layers.0"),
    ("mlp.lin2", "mlp.layers.1"),
)


def _convert_key(key: str) -> str:
    """Apply every rule in ``_KEY_RENAMES`` to *key*, in order."""
    for old, new in _KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _convert_state_dict(state_dict):
    """Drop ``teacher`` entries and rename the remaining keys.

    Keys containing the substring ``"teacher"`` are discarded (checked on the
    original key); every other key is rewritten by :func:`_convert_key`.
    Values are passed through unchanged. If two keys end up with the same
    name, the value of the later one wins.
    """
    return {
        _convert_key(k): v for k, v in state_dict.items() if "teacher" not in k
    }


def _converted_path(src: str) -> str:
    """Return *src* with ``_converted`` inserted before its file suffix.

    ``ckpt/model.pt`` -> ``ckpt/model_converted.pt``. Only the final path
    component changes, and a name without a ``.pt`` suffix still gets a
    distinct output path, so the source checkpoint is never overwritten.
    """
    path = Path(src)
    return str(path.with_name(f"{path.stem}_converted{path.suffix}"))


def main(args):
    """Convert the checkpoint at ``args.src`` and save it next to the source.

    The file is loaded onto the CPU, its ``"model"`` state dict is filtered
    and renamed by :func:`_convert_state_dict`, and the result is saved as
    ``{"model": converted_state_dict}`` at the path from
    :func:`_converted_path`.

    Args:
        args: Parsed CLI namespace; only ``args.src`` (path to a checkpoint
            readable by ``torch.load`` with a ``"model"`` entry) is read.

    Raises:
        KeyError: if the loaded checkpoint has no ``"model"`` entry.
    """
    # NOTE: torch.load is pickle-based; only convert trusted checkpoints
    # (see its `weights_only` argument).
    checkpoint = torch.load(args.src, map_location="cpu")
    state_dict = _convert_state_dict(checkpoint["model"])
    torch.save({"model": state_dict}, _converted_path(args.src))
if __name__ == "__main__":
    # Command-line entry point: parse --src and run the conversion.
    parser = argparse.ArgumentParser(
        description=(
            "Drop 'teacher' weights and rename the state-dict keys of a "
            "checkpoint's 'model' entry."
        )
    )
    parser.add_argument(
        "--src",
        type=str,
        required=True,
        help="path to the checkpoint to convert; it must contain a 'model' entry",
    )
    args = parser.parse_args()
    main(args)