|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import os

import torch
|
|
|
|
|
# Ordered (old, new) substring renames applied to every checkpoint key.
# Order matters: each rule sees the output of the previous ones, e.g. the
# "convs" rule runs after "backbone.vision_backbone" has become "image_encoder".
_KEY_RENAMES = (
    ("backbone.vision_backbone", "image_encoder"),
    ("mlp.fc1", "mlp.layers.0"),
    ("mlp.fc2", "mlp.layers.1"),
    ("convs", "neck.convs"),
    ("transformer.encoder", "memory_attention"),
    ("maskmem_backbone", "memory_encoder"),
    ("mlp.lin1", "mlp.layers.0"),
    ("mlp.lin2", "mlp.layers.1"),
)


def _rename_key(key):
    """Apply every rule in ``_KEY_RENAMES`` to *key*, in order."""
    for old, new in _KEY_RENAMES:
        key = key.replace(old, new)
    return key


def _converted_path(src):
    """Return *src* with ``_converted`` inserted before its file extension."""
    # splitext only touches the final path component's extension, so
    # directories containing ".pt" are left alone and a non-".pt" source
    # (e.g. "model.ckpt") can never map back onto itself and be overwritten.
    root, ext = os.path.splitext(src)
    return f"{root}_converted{ext}"


def main(args):
    """Convert a checkpoint's ``"model"`` state dict to the renamed key layout.

    Loads ``args.src``, drops every entry whose key contains ``"teacher"``,
    renames the remaining keys via ``_KEY_RENAMES``, and saves
    ``{"model": converted}`` next to the source as ``<name>_converted<ext>``.

    Args:
        args: Namespace with a ``src`` attribute holding the checkpoint path.

    Raises:
        KeyError: if the loaded checkpoint has no ``"model"`` entry.
    """
    # NOTE(review): torch.load unpickles the file; only convert checkpoints
    # from a trusted source.
    state = torch.load(args.src, map_location="cpu")["model"]
    # The teacher filter is checked against the original (pre-rename) key.
    converted = {
        _rename_key(key): value
        for key, value in state.items()
        if "teacher" not in key
    }
    torch.save({"model": converted}, _converted_path(args.src))
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the checkpoint path and run the conversion.
    parser = argparse.ArgumentParser(
        description="Rename the keys of a checkpoint's 'model' state dict "
        "and save the result next to the source file."
    )
    parser.add_argument(
        "--src",
        type=str,
        required=True,
        help="path to the checkpoint to convert (must contain a 'model' entry)",
    )
    args = parser.parse_args()

    main(args)
|
|