|
import torch |
|
import wget |
|
|
|
def preprocess(model, name='dino', embed_dim=384, in_chans=4):
    """Re-key a ViT state dict under ``backbone.`` and widen its patch embedding.

    Every entry of ``model`` is copied under the key ``'backbone.' + key``.
    Any key containing ``'patch_embed.proj.weight'`` is treated as a conv
    weight shaped ``(out_channels, src_channels, *kernel)``. It is replaced by
    a zero tensor with ``in_chans`` input channels (same dtype/device as the
    source), whose first ``src_channels`` channels hold the original weight.
    The extra channels start at zero (presumably so an additional non-RGB
    input initially has no effect -- confirm against the consuming model).

    The result is saved with ``torch.save`` to ``f'{name}_vit_{size}_fna.pth'``
    in the current working directory. ``size`` is ``'s'`` when
    ``embed_dim == 384`` and ``'b'`` otherwise.

    Args:
        model: Mapping of parameter names to tensors.
        name: Prefix of the output file name.
        embed_dim: Expected output-channel count of the patch embedding; also
            selects the ``'s'``/``'b'`` tag in the file name.
        in_chans: Number of input channels of the widened patch embedding.

    Returns:
        The new state dict that was saved.

    Raises:
        ValueError: If a patch-embedding weight's output channels differ from
            ``embed_dim``, or it already has more than ``in_chans`` input
            channels.
    """
    new_model = {}
    for k, v in model.items():
        if 'patch_embed.proj.weight' in k:
            out_ch, src_ch, *kernel = v.shape
            if out_ch != embed_dim:
                raise ValueError(
                    f'{k}: expected {embed_dim} output channels, got {out_ch}')
            if src_ch > in_chans:
                raise ValueError(
                    f'{k}: has {src_ch} input channels, more than in_chans={in_chans}')
            # new_zeros keeps the source dtype/device; kernel size comes from the weight.
            widened = v.new_zeros((out_ch, in_chans, *kernel))
            widened[:, :src_ch] = v
            new_model['backbone.' + k] = widened
        else:
            new_model['backbone.' + k] = v

    size = 's' if embed_dim == 384 else 'b'
    torch.save(new_model, f'{name}_vit_{size}_fna.pth')
    return new_model
|
|
|
if __name__ == "__main__":
    import os

    # (download URL, output name prefix, embed_dim, key holding the state dict or None)
    checkpoints = [
        ('https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
         'dino', 384, None),
        ('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
         'mae', 768, 'model'),
    ]

    for url, name, embed_dim, key in checkpoints:
        filename = url.rsplit('/', 1)[-1]
        # Skip the download if the checkpoint is already present, so re-running
        # the script does not fetch it again.
        if not os.path.exists(filename):
            filename = wget.download(url, out=filename)
        # Load onto CPU so conversion works without the device the checkpoint was saved from.
        state = torch.load(filename, map_location='cpu')
        if key is not None:
            state = state[key]
        preprocess(state, name, embed_dim)