Commit 2908104
Parent(s): 5199d23
Added all of the code for new space
- app.py +136 -0
- models/__init__.py +192 -0
- models/beit.py +598 -0
- models/clip/__init__.py +1 -0
- models/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- models/clip/clip.py +229 -0
- models/clip/model.py +577 -0
- models/clip/simple_tokenizer.py +132 -0
- models/deploy.py +389 -0
- models/protonet.py +51 -0
- models/resnet_v2.py +164 -0
- models/utils.py +238 -0
- models/vision_transformer.py +246 -0
- models/vit_google.py +372 -0
- requirements.txt +14 -0
app.py
ADDED
@@ -0,0 +1,136 @@
import os
import numpy as np
import time
import random
import torch
import torchvision.transforms as transforms
import gradio as gr
import matplotlib.pyplot as plt

from models import get_model
from dotmap import DotMap
from PIL import Image

#os.environ['TERM'] = 'linux'
#os.environ['TERMINFO'] = '/etc/terminfo'

# args
args = DotMap()
args.deploy = 'vanilla'
args.arch = 'dino_small_patch16'
args.no_pretrain = True
args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
args.cx = '06d75168141bc47f1'


# model
device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(args)
model.to(device)
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)


# image transforms
def test_transform():
    def _convert_image_to_rgb(im):
        return im.convert('RGB')

    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        ])

preprocess = test_transform()

@torch.no_grad()
def denormalize(x, mean, std):
    # 3, H, W
    t = x.clone()
    t.mul_(std).add_(mean)
    return torch.clamp(t, 0, 1)


# Gradio UI
def inference(query, class1_name="class1", support_imgs=None, class2_name="class2", support_imgs2=None):
    '''
    query: PIL image
    labels: list of class names
    '''

    # first, open the support images
    support_imgs = [Image.open(img) for img in support_imgs]
    support_imgs2 = [Image.open(img) for img in support_imgs2]

    labels = [class1_name, class2_name]

    supp_x = []
    supp_y = []

    for i, (class_name, support_img) in enumerate(zip([class1_name, class2_name], [support_imgs, support_imgs2])):
        for img in support_img:
            x_im = preprocess(img)
            supp_x.append(x_im)
            supp_y.append(i)

    supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)
    supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)

    query = preprocess(query).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, 3, H, W)

    print(f"Shape of supp_x: {supp_x.shape}")
    print(f"Shape of supp_y: {supp_y.shape}")
    print(f"Shape of query: {query.shape}")

    with torch.cuda.amp.autocast(True):
        output = model(supp_x, supp_y, query) # (1, 1, n_labels)

    probs = output.softmax(dim=-1).detach().cpu().numpy()

    return {k: float(v) for k, v in zip(labels, probs[0, 0])}


# DEBUG
##query = Image.open('../labrador-puppy.jpg')
#query = Image.open('/Users/hushell/Documents/Dan_tr.png')
##labels = 'dog, cat'
#labels = 'girl, sussie'
#output = inference(query, labels, n_supp=2)
#print(output)


title = "P>M>F few-shot learning pipeline"
description = "Short description: We take a ViT-small backbone, which is pre-trained with DINO and meta-trained on Meta-Dataset; for few-shot classification, we use a ProtoNet classifier. The demo can be viewed as zero-shot since the support set is built by searching images from Google. Note that you may need to play with the GIS parameters to get good support examples. Besides, GIS is not very stable, as search requests may fail for many reasons (e.g., the number of requests reaches the daily limit). This code is heavily inspired by the original HF space <a href='https://huggingface.co/spaces/hushell/pmf_with_gis' target='_blank'>here</a>"
article = "<p style='text-align: center'><a href='http://arxiv.org/abs/2204.07305' target='_blank'>Arxiv</a></p>"


gr.Interface(fn=inference,
             inputs=[
                 gr.Image(label="Image to classify", type="pil"),
                 #gr.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),

                 gr.Textbox(lines=1, label="First class name:", placeholder="Enter first class name",),
                 gr.File(label="Drag or select one or more photos of the first class", file_types=["image"], file_count="multiple"),

                 gr.Textbox(lines=1, label="Second class name:", placeholder="Enter second class name",),
                 gr.File(label="Drag or select one or more photos of the second class", file_types=["image"], file_count="multiple"),
             ],
             theme="grass",
             outputs=[
                 gr.Label(label="Predicted class probabilities"),
                 #gr.Image(type='pil', label="Support examples from Google image search"),
             ],
             title=title,
             description=description,
             article=article,
             ).launch(debug=True)
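For reference, inference() can also be exercised without the Gradio UI. A minimal sketch, assuming a local query image and two small sets of support photos (all file names below are hypothetical placeholders, not part of the space):

# Sketch only: calling inference() directly; paths are illustrative.
from PIL import Image

query_img = Image.open('query.jpg')            # PIL image to classify
dog_photos = ['dog1.jpg', 'dog2.jpg']          # support examples for the first class
cat_photos = ['cat1.jpg', 'cat2.jpg']          # support examples for the second class

# inference() opens the support files itself, so plain paths (or file objects) work here.
probs = inference(query_img, class1_name='dog', support_imgs=dog_photos,
                  class2_name='cat', support_imgs2=cat_photos)
print(probs)  # dict mapping each class name to its predicted probability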
models/__init__.py
ADDED
@@ -0,0 +1,192 @@
import os
import numpy as np
import torch
#from timm.models import create_model
from .protonet import ProtoNet
from .deploy import ProtoNet_Finetune, ProtoNet_Auto_Finetune, ProtoNet_AdaTok, ProtoNet_AdaTok_EntMin


def get_backbone(args):
    if args.arch == 'vit_base_patch16_224_in21k':
        from .vit_google import VisionTransformer, CONFIGS

        config = CONFIGS['ViT-B_16']
        model = VisionTransformer(config, 224)

        url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'
        pretrained_weights = 'pretrained_ckpts/vit_base_patch16_224_in21k.npz'

        if not os.path.exists(pretrained_weights):
            try:
                import wget
                os.makedirs('pretrained_ckpts', exist_ok=True)
                wget.download(url, pretrained_weights)
            except:
                print(f'Cannot download pretrained weights from {url}. Check if `pip install wget` works.')

        model.load_from(np.load(pretrained_weights))
        print('Pretrained weights found at {}'.format(pretrained_weights))

    elif args.arch == 'dino_base_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)

        model.load_state_dict(state_dict, strict=True)
        print('Pretrained weights found at {}'.format(url))

    elif args.arch == 'deit_base_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
        url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"
        state_dict = torch.hub.load_state_dict_from_url(url=url)["model"]

        for k in ['head.weight', 'head.bias']:
            if k in state_dict:
                print(f"removing key {k} from pretrained checkpoint")
                del state_dict[k]

        model.load_state_dict(state_dict, strict=True)
        print('Pretrained weights found at {}'.format(url))

    elif args.arch == 'deit_small_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
        url = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
        state_dict = torch.hub.load_state_dict_from_url(url=url)["model"]

        for k in ['head.weight', 'head.bias']:
            if k in state_dict:
                print(f"removing key {k} from pretrained checkpoint")
                del state_dict[k]

        model.load_state_dict(state_dict, strict=True)
        print('Pretrained weights found at {}'.format(url))

    elif args.arch == 'dino_small_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)

        if not args.no_pretrain:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)

            model.load_state_dict(state_dict, strict=True)
            print('Pretrained weights found at {}'.format(url))

    elif args.arch == 'beit_base_patch16_224_pt22k':
        from .beit import default_pretrained_model
        model = default_pretrained_model(args)
        print('Pretrained BEiT loaded')

    elif args.arch == 'clip_base_patch16_224':
        from . import clip
        model, _ = clip.load('ViT-B/16', 'cpu')

    elif args.arch == 'clip_resnet50':
        from . import clip
        model, _ = clip.load('RN50', 'cpu')

    elif args.arch == 'dino_resnet50':
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        model.fc = torch.nn.Identity()

        if not args.no_pretrain:
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
                map_location="cpu",
            )
            model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'resnet50':
        from torchvision.models.resnet import resnet50

        pretrained = not args.no_pretrain
        model = resnet50(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'resnet18':
        from torchvision.models.resnet import resnet18

        pretrained = not args.no_pretrain
        model = resnet18(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'dino_xcit_medium_24_p16':
        model = torch.hub.load('facebookresearch/xcit:main', 'xcit_medium_24_p16')
        model.head = torch.nn.Identity()
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'dino_xcit_medium_24_p8':
        model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')

    elif args.arch == 'simclrv2_resnet50':
        import sys
        sys.path.insert(
            0,
            'cog',
        )
        import model_utils

        model_utils.MODELS_ROOT_DIR = 'cog/models'
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts/simclrv2_resnet50.pth')
        resnet, _ = model_utils.load_pretrained_backbone(args.arch, ckpt_file)

        class Wrapper(torch.nn.Module):
            def __init__(self, model):
                super(Wrapper, self).__init__()
                self.model = model

            def forward(self, x):
                return self.model(x, apply_fc=False)

        model = Wrapper(resnet)

    elif args.arch in ['mocov2_resnet50', 'swav_resnet50', 'barlow_resnet50']:
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts_converted/{}.pth'.format(args.arch))
        ckpt = torch.load(ckpt_file)

        msg = model.load_state_dict(ckpt, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

        # remove the fully-connected layer
        model.fc = torch.nn.Identity()

    else:
        raise ValueError(f'{args.arch} is not considered in the current code.')

    return model


def get_model(args):
    backbone = get_backbone(args)

    if args.deploy == 'vanilla':
        model = ProtoNet(backbone)
    elif args.deploy == 'finetune':
        model = ProtoNet_Finetune(backbone, args.ada_steps, args.ada_lr, args.aug_prob, args.aug_types)
    elif args.deploy == 'finetune_autolr':
        model = ProtoNet_Auto_Finetune(backbone, args.ada_steps, args.aug_prob, args.aug_types)
    elif args.deploy == 'ada_tokens':
        model = ProtoNet_AdaTok(backbone, args.num_adapters,
                                args.ada_steps, args.ada_lr)
    elif args.deploy == 'ada_tokens_entmin':
        model = ProtoNet_AdaTok_EntMin(backbone, args.num_adapters,
                                       args.ada_steps, args.ada_lr)
    else:
        raise ValueError(f'deploy method {args.deploy} is not supported.')
    return model
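As a usage note, get_model() only reads the args attributes required by the chosen arch/deploy branch. A minimal sketch of building the same vanilla ProtoNet that app.py uses (assuming this package is importable as models):

# Sketch: vanilla ProtoNet on a randomly initialized DINO ViT-small backbone.
from dotmap import DotMap
from models import get_model

args = DotMap()
args.arch = 'dino_small_patch16'   # selects the backbone branch in get_backbone()
args.no_pretrain = True            # skip downloading the DINO checkpoint here
args.deploy = 'vanilla'            # wrap the backbone in a plain ProtoNet

model = get_model(args)            # returns ProtoNet(backbone)
model.eval()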
models/beit.py
ADDED
@@ -0,0 +1,598 @@
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import math
from functools import partial
from scipy import interpolate

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
#from timm.models.registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out to follow the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class RelativePositionBias(nn.Module):

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        # trunc_normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

        if num_classes > 0:
            trunc_normal_(self.head.weight, std=.02)
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias=rel_pos_bias)

        x = self.norm(x)
        if self.fc_norm is not None:
            t = x[:, 1:, :]
            return self.fc_norm(t.mean(1))
        else:
            return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


#@register_model
def beit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


#@register_model
def beit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


#@register_model
def beit_large_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


#@register_model
def beit_large_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


#@register_model
def beit_large_patch16_512(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                _load(child, prefix + name + '.')

    _load(model, prefix=prefix)

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        keep_flag = True
        for ignore_key in ignore_missing.split('|'):
            if ignore_key in key:
                keep_flag = False
                break
        if keep_flag:
            warn_missing_keys.append(key)
        else:
            ignore_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))


def default_pretrained_model(args):
    model = beit_base_patch16_224(
        pretrained=False,
        img_size=args.image_size,
        num_classes=0,
        drop_rate=0.,
        drop_path_rate=0.1,
        attn_drop_rate=0.,
        #drop_block_rate=None,
        use_mean_pooling=True,
        init_scale=0.001,
        use_rel_pos_bias=True,
        use_abs_pos_emb=False,
        init_values=0.1,
    )

    #url = 'https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k.pth'
    url = 'https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth'

    checkpoint = torch.hub.load_state_dict_from_url(
        url, map_location='cpu', check_hash=True)
    print('Pretrained weights found at {}'.format(url))

    # select key
    checkpoint_model = None
    for model_key in ['model', 'module']:
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break
    if checkpoint_model is None:
        checkpoint_model = checkpoint

    # remove head
    state_dict = model.state_dict()
    for k in ['head.weight', 'head.bias']:
        #if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        if k in checkpoint_model:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    # resize rel_pos_bias
    if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
        print("Expand the shared relative position embedding to each transformer block. ")
        num_layers = model.get_num_layers()
        rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"]
        for i in range(num_layers):
            checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()

        checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            checkpoint_model.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                print("Position interpolate for %s from %dx%d to %dx%d" % (
                    key, src_size, src_size, dst_size, dst_size))
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.090307:
                #     q = 1.090307

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                print("Original positions = %s" % str(x))
                print("Target positions = %s" % str(dx))

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                checkpoint_model[key] = new_rel_pos_bias

    # interpolate position embedding
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

    load_state_dict(model, checkpoint_model)
    return model
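Note that default_pretrained_model() reads args.image_size, which the DotMap in app.py does not set; the BEiT branch therefore needs it supplied explicitly. A hedged sketch of loading this backbone on its own:

# Sketch: standalone use of the BEiT backbone defined above (assumes internet access
# for the checkpoint download; image_size is required by default_pretrained_model).
import torch
from dotmap import DotMap
from models.beit import default_pretrained_model

args = DotMap()
args.image_size = 224                          # matches the 224x224 pre-trained checkpoint

backbone = default_pretrained_model(args)      # BEiT-base/16 with num_classes=0 (Identity head)
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))   # mean-pooled features, shape (1, 768)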
models/clip/__init__.py
ADDED
@@ -0,0 +1 @@
from .clip import *
models/clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
models/clip/clip.py
ADDED
@@ -0,0 +1,229 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model, build_vision_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if torch.__version__.split(".") < ["1", "7", "1"]:
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        #model = build_model(state_dict or model.state_dict()).to(device)
        model = build_vision_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
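For reference, tokenize() above does not depend on any downloaded weights, so it can be sanity-checked on its own. A short sketch, assuming the package is importable as models.clip:

# Sketch: tokenizing prompts with the bundled BPE tokenizer.
from models import clip

tokens = clip.tokenize(["a photo of a dog", "a photo of a cat"])
print(tokens.shape)   # torch.Size([2, 77]); context_length defaults to 77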
models/clip/model.py
ADDED
@@ -0,0 +1,577 @@
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
class Bottleneck(nn.Module):
|
12 |
+
expansion = 4
|
13 |
+
|
14 |
+
def __init__(self, inplanes, planes, stride=1):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
18 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
19 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
|
24 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
25 |
+
|
26 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
27 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
28 |
+
|
29 |
+
self.relu = nn.ReLU(inplace=True)
|
30 |
+
self.downsample = None
|
31 |
+
self.stride = stride
|
32 |
+
|
33 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
34 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
35 |
+
self.downsample = nn.Sequential(OrderedDict([
|
36 |
+
("-1", nn.AvgPool2d(stride)),
|
37 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
38 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
39 |
+
]))
|
40 |
+
|
41 |
+
def forward(self, x: torch.Tensor):
|
42 |
+
identity = x
|
43 |
+
|
44 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
45 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
46 |
+
out = self.avgpool(out)
|
47 |
+
out = self.bn3(self.conv3(out))
|
48 |
+
|
49 |
+
if self.downsample is not None:
|
50 |
+
identity = self.downsample(x)
|
51 |
+
|
52 |
+
out += identity
|
53 |
+
out = self.relu(out)
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class AttentionPool2d(nn.Module):
|
58 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
59 |
+
super().__init__()
|
60 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
61 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
62 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
63 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
64 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
65 |
+
self.num_heads = num_heads
|
66 |
+
|
67 |
+
def interpolate_pos_encoding(self, x, h0, w0):
|
68 |
+
assert w0 == h0, f'{self} only support square images!'
|
69 |
+
pos_embed = self.positional_embedding.unsqueeze(1).to(x.dtype)
|
70 |
+
npatch = x.shape[0] - 1
|
71 |
+
N = pos_embed.shape[0] - 1
|
72 |
+
if npatch == N:
|
73 |
+
return pos_embed
|
74 |
+
class_pos_embed = pos_embed[0]
|
75 |
+
patch_pos_embed = pos_embed[1:]
|
76 |
+
dim = x.shape[-1]
|
77 |
+
# we add a small number to avoid floating point error in the interpolation
|
78 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
79 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
80 |
+
patch_pos_embed = nn.functional.interpolate(
|
81 |
+
patch_pos_embed.reshape(int(math.sqrt(N)), int(math.sqrt(N)), 1, dim).permute(2, 3, 0, 1),
|
82 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
83 |
+
mode='bicubic',
|
84 |
+
align_corners=False,
|
85 |
+
recompute_scale_factor=False
|
86 |
+
)
|
87 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
88 |
+
patch_pos_embed = patch_pos_embed.permute(2, 3, 0, 1).view(-1, 1, dim)
|
89 |
+
return torch.cat((class_pos_embed.unsqueeze(1), patch_pos_embed), dim=0)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
B, C, H, W = x.shape
|
93 |
+
|
94 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
95 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
96 |
+
x = x + self.interpolate_pos_encoding(x, H, W) # (HW+1)NC
|
97 |
+
x, _ = F.multi_head_attention_forward(
|
98 |
+
query=x, key=x, value=x,
|
99 |
+
embed_dim_to_check=x.shape[-1],
|
100 |
+
num_heads=self.num_heads,
|
101 |
+
q_proj_weight=self.q_proj.weight,
|
102 |
+
k_proj_weight=self.k_proj.weight,
|
103 |
+
v_proj_weight=self.v_proj.weight,
|
104 |
+
in_proj_weight=None,
|
105 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
106 |
+
bias_k=None,
|
107 |
+
bias_v=None,
|
108 |
+
add_zero_attn=False,
|
109 |
+
dropout_p=0,
|
110 |
+
out_proj_weight=self.c_proj.weight,
|
111 |
+
out_proj_bias=self.c_proj.bias,
|
112 |
+
use_separate_proj_weight=True,
|
113 |
+
training=self.training,
|
114 |
+
need_weights=False
|
115 |
+
)
|
116 |
+
|
117 |
+
return x[0]
|
118 |
+
|
119 |
+
|
120 |
+
class ModifiedResNet(nn.Module):
|
121 |
+
"""
|
122 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
123 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
124 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
125 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
129 |
+
super().__init__()
|
130 |
+
self.output_dim = output_dim
|
131 |
+
self.input_resolution = input_resolution
|
132 |
+
|
133 |
+
# the 3-layer stem
|
134 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
135 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
136 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
137 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
138 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
139 |
+
self.bn3 = nn.BatchNorm2d(width)
|
140 |
+
self.avgpool = nn.AvgPool2d(2)
|
141 |
+
self.relu = nn.ReLU(inplace=True)
|
142 |
+
|
143 |
+
# residual layers
|
144 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
145 |
+
self.layer1 = self._make_layer(width, layers[0])
|
146 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
147 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
148 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
149 |
+
|
150 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
151 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
152 |
+
#self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
153 |
+
|
154 |
+
def _make_layer(self, planes, blocks, stride=1):
|
155 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
156 |
+
|
157 |
+
self._inplanes = planes * Bottleneck.expansion
|
158 |
+
for _ in range(1, blocks):
|
159 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
160 |
+
|
161 |
+
return nn.Sequential(*layers)
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
def stem(x):
|
165 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
166 |
+
x = self.relu(bn(conv(x)))
|
167 |
+
x = self.avgpool(x)
|
168 |
+
return x
|
169 |
+
|
170 |
+
x = x.type(self.conv1.weight.dtype)
|
171 |
+
x = stem(x)
|
172 |
+
x = self.layer1(x)
|
173 |
+
x = self.layer2(x)
|
174 |
+
x = self.layer3(x)
|
175 |
+
x = self.layer4(x)
|
176 |
+
x = self.attnpool(x)
|
177 |
+
#x = self.gap(x)
|
178 |
+
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class LayerNorm(nn.LayerNorm):
|
183 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
184 |
+
|
185 |
+
def forward(self, x: torch.Tensor):
|
186 |
+
orig_type = x.dtype
|
187 |
+
ret = super().forward(x.type(torch.float32))
|
188 |
+
return ret.type(orig_type)
|
189 |
+
|
190 |
+
|
191 |
+
class QuickGELU(nn.Module):
|
192 |
+
def forward(self, x: torch.Tensor):
|
193 |
+
return x * torch.sigmoid(1.702 * x)
|
194 |
+
|
195 |
+
|
196 |
+
class ResidualAttentionBlock(nn.Module):
|
197 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
198 |
+
super().__init__()
|
199 |
+
|
200 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
201 |
+
self.ln_1 = LayerNorm(d_model)
|
202 |
+
self.mlp = nn.Sequential(OrderedDict([
|
203 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
204 |
+
("gelu", QuickGELU()),
|
205 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
206 |
+
]))
|
207 |
+
self.ln_2 = LayerNorm(d_model)
|
208 |
+
self.attn_mask = attn_mask
|
209 |
+
|
210 |
+
def attention(self, x: torch.Tensor):
|
211 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
212 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
213 |
+
|
214 |
+
def forward(self, x: torch.Tensor):
|
215 |
+
x = x + self.attention(self.ln_1(x))
|
216 |
+
x = x + self.mlp(self.ln_2(x))
|
217 |
+
return x
|
218 |
+
|
219 |
+
|
220 |
+
class Transformer(nn.Module):
|
221 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
222 |
+
super().__init__()
|
223 |
+
self.width = width
|
224 |
+
self.layers = layers
|
225 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
226 |
+
|
227 |
+
def forward(self, x: torch.Tensor):
|
228 |
+
return self.resblocks(x)
|
229 |
+
|
230 |
+
|
231 |
+
class VisionTransformer(nn.Module):
|
232 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
233 |
+
super().__init__()
|
234 |
+
self.input_resolution = input_resolution
|
235 |
+
self.output_dim = output_dim
|
236 |
+
self.patch_size = patch_size
|
237 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
238 |
+
|
239 |
+
scale = width ** -0.5
|
240 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
241 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
242 |
+
self.ln_pre = LayerNorm(width)
|
243 |
+
|
244 |
+
self.transformer = Transformer(width, layers, heads)
|
245 |
+
|
246 |
+
self.ln_post = LayerNorm(width)
|
247 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
248 |
+
|
249 |
+
def interpolate_pos_encoding(self, x, h, w):
|
250 |
+
pos_embed = self.positional_embedding.unsqueeze(0).to(x.dtype)
|
251 |
+
npatch = x.shape[1] - 1
|
252 |
+
N = pos_embed.shape[1] - 1
|
253 |
+
if npatch == N and w == h:
|
254 |
+
return pos_embed
|
255 |
+
class_pos_embed = pos_embed[:, 0]
|
256 |
+
patch_pos_embed = pos_embed[:, 1:]
|
257 |
+
dim = x.shape[-1]
|
258 |
+
w0 = w // self.patch_size
|
259 |
+
h0 = h // self.patch_size
|
260 |
+
# we add a small number to avoid floating point error in the interpolation
|
261 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
262 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
263 |
+
patch_pos_embed = nn.functional.interpolate(
|
264 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
265 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
266 |
+
mode='bicubic',
|
267 |
+
align_corners=False,
|
268 |
+
recompute_scale_factor=False
|
269 |
+
)
|
270 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
271 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
272 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
273 |
+
|
274 |
+
def forward(self, x: torch.Tensor):
|
275 |
+
B, C, H, W = x.shape
|
276 |
+
|
277 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
278 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
279 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
280 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
281 |
+
x = x + self.interpolate_pos_encoding(x, H, W)
|
282 |
+
x = self.ln_pre(x)
|
283 |
+
|
284 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
285 |
+
x = self.transformer(x)
|
286 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
287 |
+
|
288 |
+
x = self.ln_post(x[:, 0, :])
|
289 |
+
|
290 |
+
if self.proj is not None:
|
291 |
+
x = x @ self.proj
|
292 |
+
|
293 |
+
return x
|
294 |
+
|
295 |
+
|
296 |
+
class VisionBackbone(nn.Module):
|
297 |
+
def __init__(self,
|
298 |
+
embed_dim: int,
|
299 |
+
# vision
|
300 |
+
image_resolution: int,
|
301 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
302 |
+
vision_width: int,
|
303 |
+
vision_patch_size: int,
|
304 |
+
):
|
305 |
+
super().__init__()
|
306 |
+
|
307 |
+
if isinstance(vision_layers, (tuple, list)):
|
308 |
+
vision_heads = vision_width * 32 // 64
|
309 |
+
self.visual = ModifiedResNet(
|
310 |
+
layers=vision_layers,
|
311 |
+
output_dim=embed_dim,
|
312 |
+
heads=vision_heads,
|
313 |
+
input_resolution=image_resolution,
|
314 |
+
width=vision_width
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
vision_heads = vision_width // 64
|
318 |
+
self.visual = VisionTransformer(
|
319 |
+
input_resolution=image_resolution,
|
320 |
+
patch_size=vision_patch_size,
|
321 |
+
width=vision_width,
|
322 |
+
layers=vision_layers,
|
323 |
+
heads=vision_heads,
|
324 |
+
output_dim=embed_dim
|
325 |
+
)
|
326 |
+
|
327 |
+
#self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
328 |
+
|
329 |
+
self.initialize_parameters()
|
330 |
+
|
331 |
+
def initialize_parameters(self):
|
332 |
+
if isinstance(self.visual, ModifiedResNet):
|
333 |
+
if self.visual.attnpool is not None:
|
334 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
335 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
336 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
337 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
338 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
339 |
+
|
340 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
341 |
+
for name, param in resnet_block.named_parameters():
|
342 |
+
if name.endswith("bn3.weight"):
|
343 |
+
nn.init.zeros_(param)
|
344 |
+
|
345 |
+
@property
|
346 |
+
def dtype(self):
|
347 |
+
return self.visual.conv1.weight.dtype
|
348 |
+
|
349 |
+
def forward(self, image):
|
350 |
+
return self.visual(image.type(self.dtype))
|
351 |
+
|
352 |
+
|
353 |
+
class CLIP(nn.Module):
|
354 |
+
def __init__(self,
|
355 |
+
embed_dim: int,
|
356 |
+
# vision
|
357 |
+
image_resolution: int,
|
358 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
359 |
+
vision_width: int,
|
360 |
+
vision_patch_size: int,
|
361 |
+
# text
|
362 |
+
context_length: int,
|
363 |
+
vocab_size: int,
|
364 |
+
transformer_width: int,
|
365 |
+
transformer_heads: int,
|
366 |
+
transformer_layers: int
|
367 |
+
):
|
368 |
+
super().__init__()
|
369 |
+
|
370 |
+
self.context_length = context_length
|
371 |
+
|
372 |
+
if isinstance(vision_layers, (tuple, list)):
|
373 |
+
vision_heads = vision_width * 32 // 64
|
374 |
+
self.visual = ModifiedResNet(
|
375 |
+
layers=vision_layers,
|
376 |
+
output_dim=embed_dim,
|
377 |
+
heads=vision_heads,
|
378 |
+
input_resolution=image_resolution,
|
379 |
+
width=vision_width
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
vision_heads = vision_width // 64
|
383 |
+
self.visual = VisionTransformer(
|
384 |
+
input_resolution=image_resolution,
|
385 |
+
patch_size=vision_patch_size,
|
386 |
+
width=vision_width,
|
387 |
+
layers=vision_layers,
|
388 |
+
heads=vision_heads,
|
389 |
+
output_dim=embed_dim
|
390 |
+
)
|
391 |
+
|
392 |
+
self.transformer = Transformer(
|
393 |
+
width=transformer_width,
|
394 |
+
layers=transformer_layers,
|
395 |
+
heads=transformer_heads,
|
396 |
+
attn_mask=self.build_attention_mask()
|
397 |
+
)
|
398 |
+
|
399 |
+
self.vocab_size = vocab_size
|
400 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
401 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
402 |
+
self.ln_final = LayerNorm(transformer_width)
|
403 |
+
|
404 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
405 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
406 |
+
|
407 |
+
self.initialize_parameters()
|
408 |
+
|
409 |
+
def initialize_parameters(self):
|
410 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
411 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
412 |
+
|
413 |
+
if isinstance(self.visual, ModifiedResNet):
|
414 |
+
if self.visual.attnpool is not None:
|
415 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
416 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
417 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
418 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
419 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
420 |
+
|
421 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
422 |
+
for name, param in resnet_block.named_parameters():
|
423 |
+
if name.endswith("bn3.weight"):
|
424 |
+
nn.init.zeros_(param)
|
425 |
+
|
426 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
427 |
+
attn_std = self.transformer.width ** -0.5
|
428 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
429 |
+
for block in self.transformer.resblocks:
|
430 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
431 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
432 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
433 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
434 |
+
|
435 |
+
if self.text_projection is not None:
|
436 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
437 |
+
|
438 |
+
def build_attention_mask(self):
|
439 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
440 |
+
# pytorch uses additive attention mask; fill with -inf
|
441 |
+
mask = torch.empty(self.context_length, self.context_length)
|
442 |
+
mask.fill_(float("-inf"))
|
443 |
+
mask.triu_(1) # zero out the lower diagonal
|
444 |
+
return mask
|
445 |
+
|
446 |
+
@property
|
447 |
+
def dtype(self):
|
448 |
+
return self.visual.conv1.weight.dtype
|
449 |
+
|
450 |
+
def encode_image(self, image):
|
451 |
+
return self.visual(image.type(self.dtype))
|
452 |
+
|
453 |
+
def encode_text(self, text):
|
454 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
455 |
+
|
456 |
+
x = x + self.positional_embedding.type(self.dtype)
|
457 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
458 |
+
x = self.transformer(x)
|
459 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
460 |
+
x = self.ln_final(x).type(self.dtype)
|
461 |
+
|
462 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
463 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
464 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
465 |
+
|
466 |
+
return x
|
467 |
+
|
468 |
+
def forward(self, image, text):
|
469 |
+
image_features = self.encode_image(image)
|
470 |
+
text_features = self.encode_text(text)
|
471 |
+
|
472 |
+
# normalized features
|
473 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
474 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
475 |
+
|
476 |
+
# cosine similarity as logits
|
477 |
+
logit_scale = self.logit_scale.exp()
|
478 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
479 |
+
logits_per_text = logits_per_image.t()
|
480 |
+
|
481 |
+
# shape = [global_batch_size, global_batch_size]
|
482 |
+
return logits_per_image, logits_per_text
|
483 |
+
|
484 |
+
|
485 |
+
def convert_weights(model: nn.Module):
|
486 |
+
"""Convert applicable model parameters to fp16"""
|
487 |
+
|
488 |
+
def _convert_weights_to_fp16(l):
|
489 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
490 |
+
l.weight.data = l.weight.data.half()
|
491 |
+
if l.bias is not None:
|
492 |
+
l.bias.data = l.bias.data.half()
|
493 |
+
|
494 |
+
if isinstance(l, nn.MultiheadAttention):
|
495 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
496 |
+
tensor = getattr(l, attr)
|
497 |
+
if tensor is not None:
|
498 |
+
tensor.data = tensor.data.half()
|
499 |
+
|
500 |
+
for name in ["text_projection", "proj"]:
|
501 |
+
if hasattr(l, name):
|
502 |
+
attr = getattr(l, name)
|
503 |
+
if attr is not None:
|
504 |
+
attr.data = attr.data.half()
|
505 |
+
|
506 |
+
model.apply(_convert_weights_to_fp16)
|
507 |
+
|
508 |
+
|
509 |
+
def build_model(state_dict: dict):
|
510 |
+
vit = "visual.proj" in state_dict
|
511 |
+
|
512 |
+
if vit:
|
513 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
514 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
515 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
516 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
517 |
+
image_resolution = vision_patch_size * grid_size
|
518 |
+
else:
|
519 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
520 |
+
vision_layers = tuple(counts)
|
521 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
522 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
523 |
+
vision_patch_size = None
|
524 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
525 |
+
image_resolution = output_width * 32
|
526 |
+
|
527 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
528 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
529 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
530 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
531 |
+
transformer_heads = transformer_width // 64
|
532 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
533 |
+
|
534 |
+
model = CLIP(
|
535 |
+
embed_dim,
|
536 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
537 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
538 |
+
)
|
539 |
+
|
540 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
541 |
+
if key in state_dict:
|
542 |
+
del state_dict[key]
|
543 |
+
|
544 |
+
convert_weights(model)
|
545 |
+
model.load_state_dict(state_dict)
|
546 |
+
return model.eval()
|
547 |
+
|
548 |
+
|
549 |
+
def build_vision_model(state_dict: dict):
|
550 |
+
vit = "visual.proj" in state_dict
|
551 |
+
|
552 |
+
if vit:
|
553 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
554 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
555 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
556 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
557 |
+
image_resolution = vision_patch_size * grid_size
|
558 |
+
else:
|
559 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
560 |
+
vision_layers = tuple(counts)
|
561 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
562 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
563 |
+
vision_patch_size = None
|
564 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
565 |
+
image_resolution = output_width * 32
|
566 |
+
|
567 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
568 |
+
|
569 |
+
model = VisionBackbone(
|
570 |
+
embed_dim,
|
571 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
572 |
+
)
|
573 |
+
|
574 |
+
convert_weights(model)
|
575 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
576 |
+
print(f'clip.build_vision_model: pretrained weights loaded with message: {msg}')
|
577 |
+
return model.eval()
|
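A hedged sketch of how build_model()/build_vision_model() in model.py above are typically fed: all hyper-parameters are inferred from tensor shapes in a CLIP checkpoint's state_dict, so no separate config file is needed. The local checkpoint filename below is a placeholder:

import torch
from models.clip.model import build_model, build_vision_model  # paths as added in this commit

jit_model = torch.jit.load("ViT-B-32.pt", map_location="cpu")  # placeholder path to an OpenAI CLIP JIT checkpoint
state_dict = jit_model.state_dict()

clip_model = build_model(state_dict)          # full CLIP (image + text towers), weights converted to fp16
vision_only = build_vision_model(state_dict)  # VisionBackbone: image tower only, loaded with strict=False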
models/clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import ftfy
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache()
|
11 |
+
def default_bpe():
|
12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def bytes_to_unicode():
|
17 |
+
"""
|
18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
19 |
+
The reversible bpe codes work on unicode strings.
|
20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
25 |
+
"""
|
26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8+n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def get_pairs(word):
|
39 |
+
"""Return set of symbol pairs in a word.
|
40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
41 |
+
"""
|
42 |
+
pairs = set()
|
43 |
+
prev_char = word[0]
|
44 |
+
for char in word[1:]:
|
45 |
+
pairs.add((prev_char, char))
|
46 |
+
prev_char = char
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def basic_clean(text):
|
51 |
+
text = ftfy.fix_text(text)
|
52 |
+
text = html.unescape(html.unescape(text))
|
53 |
+
return text.strip()
|
54 |
+
|
55 |
+
|
56 |
+
def whitespace_clean(text):
|
57 |
+
text = re.sub(r'\s+', ' ', text)
|
58 |
+
text = text.strip()
|
59 |
+
return text
|
60 |
+
|
61 |
+
|
62 |
+
class SimpleTokenizer(object):
|
63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
64 |
+
self.byte_encoder = bytes_to_unicode()
|
65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
67 |
+
merges = merges[1:49152-256-2+1]
|
68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
69 |
+
vocab = list(bytes_to_unicode().values())
|
70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
71 |
+
for merge in merges:
|
72 |
+
vocab.append(''.join(merge))
|
73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
79 |
+
|
80 |
+
def bpe(self, token):
|
81 |
+
if token in self.cache:
|
82 |
+
return self.cache[token]
|
83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
84 |
+
pairs = get_pairs(word)
|
85 |
+
|
86 |
+
if not pairs:
|
87 |
+
return token+'</w>'
|
88 |
+
|
89 |
+
while True:
|
90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
91 |
+
if bigram not in self.bpe_ranks:
|
92 |
+
break
|
93 |
+
first, second = bigram
|
94 |
+
new_word = []
|
95 |
+
i = 0
|
96 |
+
while i < len(word):
|
97 |
+
try:
|
98 |
+
j = word.index(first, i)
|
99 |
+
new_word.extend(word[i:j])
|
100 |
+
i = j
|
101 |
+
except:
|
102 |
+
new_word.extend(word[i:])
|
103 |
+
break
|
104 |
+
|
105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
106 |
+
new_word.append(first+second)
|
107 |
+
i += 2
|
108 |
+
else:
|
109 |
+
new_word.append(word[i])
|
110 |
+
i += 1
|
111 |
+
new_word = tuple(new_word)
|
112 |
+
word = new_word
|
113 |
+
if len(word) == 1:
|
114 |
+
break
|
115 |
+
else:
|
116 |
+
pairs = get_pairs(word)
|
117 |
+
word = ' '.join(word)
|
118 |
+
self.cache[token] = word
|
119 |
+
return word
|
120 |
+
|
121 |
+
def encode(self, text):
|
122 |
+
bpe_tokens = []
|
123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
124 |
+
for token in re.findall(self.pat, text):
|
125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
127 |
+
return bpe_tokens
|
128 |
+
|
129 |
+
def decode(self, tokens):
|
130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
132 |
+
return text
|
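A small round-trip sketch for SimpleTokenizer above; it assumes the bundled bpe_simple_vocab_16e6.txt.gz sits next to simple_tokenizer.py, which this commit provides:

from models.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("A photo of a corgi!")  # ftfy fix-up, lower-casing, then byte-level BPE -> list of token ids
text = tok.decode(ids)                   # word boundaries are restored from the </w> markers
print(ids, text)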
models/deploy.py
ADDED
@@ -0,0 +1,389 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.distributed as dist
|
6 |
+
from copy import deepcopy
|
7 |
+
from tqdm import tqdm
|
8 |
+
from timm.utils import accuracy
|
9 |
+
from .protonet import ProtoNet
|
10 |
+
from .utils import trunc_normal_, DiffAugment
|
11 |
+
|
12 |
+
|
13 |
+
def is_dist_avail_and_initialized():
|
14 |
+
if not dist.is_available():
|
15 |
+
return False
|
16 |
+
if not dist.is_initialized():
|
17 |
+
return False
|
18 |
+
return True
|
19 |
+
|
20 |
+
|
21 |
+
def get_rank():
|
22 |
+
if not is_dist_avail_and_initialized():
|
23 |
+
return 0
|
24 |
+
return dist.get_rank()
|
25 |
+
|
26 |
+
|
27 |
+
def is_main_process():
|
28 |
+
return get_rank() == 0
|
29 |
+
|
30 |
+
|
31 |
+
@torch.jit.script
|
32 |
+
def entropy_loss(x):
|
33 |
+
return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()
|
34 |
+
|
35 |
+
|
36 |
+
def unique_indices(x):
|
37 |
+
"""
|
38 |
+
Ref: https://github.com/rusty1s/pytorch_unique
|
39 |
+
"""
|
40 |
+
unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
|
41 |
+
perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
|
42 |
+
inverse, perm = inverse.flip([0]), perm.flip([0])
|
43 |
+
perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
|
44 |
+
return unique, perm
|
45 |
+
|
46 |
+
|
47 |
+
class ProtoNet_Auto_Finetune(ProtoNet):
|
48 |
+
def __init__(self, backbone, num_iters=50, aug_prob=0.9,
|
49 |
+
aug_types=['color', 'translation'], lr_lst=[0.01, 0.001, 0.0001]):
|
50 |
+
super().__init__(backbone)
|
51 |
+
self.num_iters = num_iters
|
52 |
+
self.lr_lst = lr_lst
|
53 |
+
self.aug_types = aug_types
|
54 |
+
self.aug_prob = aug_prob
|
55 |
+
|
56 |
+
state_dict = backbone.state_dict()
|
57 |
+
self.backbone_state = deepcopy(state_dict)
|
58 |
+
|
59 |
+
def forward(self, supp_x, supp_y, qry_x):
|
60 |
+
"""
|
61 |
+
supp_x.shape = [B, nSupp, C, H, W]
|
62 |
+
supp_y.shape = [B, nSupp]
|
63 |
+
qry_x.shape = [B, nQry, C, H, W]
|
64 |
+
"""
|
65 |
+
B, nSupp, C, H, W = supp_x.shape
|
66 |
+
num_classes = supp_y.max() + 1 # NOTE: assume B==1
|
67 |
+
device = qry_x.device
|
68 |
+
|
69 |
+
criterion = nn.CrossEntropyLoss()
|
70 |
+
supp_x = supp_x.view(-1, C, H, W)
|
71 |
+
qry_x = qry_x.view(-1, C, H, W)
|
72 |
+
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
|
73 |
+
supp_y = supp_y.view(-1)
|
74 |
+
|
75 |
+
def single_step(z, mode=True, x=None, y=None, y_1hot=None):
|
76 |
+
'''
|
77 |
+
z = Aug(supp_x) or qry_x
|
78 |
+
global vars: supp_x, supp_y, supp_y_1hot
|
79 |
+
'''
|
80 |
+
with torch.set_grad_enabled(mode):
|
81 |
+
# recalculate prototypes from supp_x with updated backbone
|
82 |
+
proto_f = self.backbone.forward(x).unsqueeze(0)
|
83 |
+
|
84 |
+
if y_1hot is None:
|
85 |
+
prototypes = proto_f
|
86 |
+
else:
|
87 |
+
prototypes = torch.bmm(y_1hot.float(), proto_f) # B, nC, d
|
88 |
+
prototypes = prototypes / y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
|
89 |
+
|
90 |
+
# compute feature for z
|
91 |
+
feat = self.backbone.forward(z)
|
92 |
+
feat = feat.view(B, z.shape[0], -1) # B, nQry, d
|
93 |
+
|
94 |
+
# classification
|
95 |
+
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
|
96 |
+
loss = None
|
97 |
+
|
98 |
+
if mode: # if enable grad, compute loss
|
99 |
+
loss = criterion(logits.view(len(y), -1), y)
|
100 |
+
|
101 |
+
return logits, loss
|
102 |
+
|
103 |
+
# load trained weights
|
104 |
+
self.backbone.load_state_dict(self.backbone_state, strict=True)
|
105 |
+
|
106 |
+
#zz = DiffAugment(supp_x, ["color", "offset", "offset_h", "offset_v", "translation", "cutout"], 1., detach=True)
|
107 |
+
proto_y, proto_i = unique_indices(supp_y)
|
108 |
+
proto_x = supp_x[proto_i]
|
109 |
+
zz_i = np.setdiff1d(range(len(supp_x)), proto_i.cpu().numpy())
|
110 |
+
zz_x = supp_x[zz_i]
|
111 |
+
zz_y = supp_y[zz_i]
|
112 |
+
|
113 |
+
best_lr = 0
|
114 |
+
max_acc1 = 0
|
115 |
+
|
116 |
+
if len(zz_y) > 0:
|
117 |
+
# eval non-finetuned weights (lr=0)
|
118 |
+
logits, _ = single_step(zz_x, False, x=proto_x)
|
119 |
+
max_acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
|
120 |
+
print(f'## *lr = 0: acc1 = {max_acc1}\n')
|
121 |
+
|
122 |
+
for lr in self.lr_lst:
|
123 |
+
# create optimizer
|
124 |
+
opt = torch.optim.Adam(self.backbone.parameters(),
|
125 |
+
lr=lr,
|
126 |
+
betas=(0.9, 0.999),
|
127 |
+
weight_decay=0.)
|
128 |
+
|
129 |
+
# main loop
|
130 |
+
_num_iters = 50
|
131 |
+
pbar = tqdm(range(_num_iters)) if is_main_process() else range(_num_iters)
|
132 |
+
for i in pbar:
|
133 |
+
opt.zero_grad()
|
134 |
+
z = DiffAugment(proto_x, self.aug_types, self.aug_prob, detach=True)
|
135 |
+
_, loss = single_step(z, True, x=proto_x, y=proto_y)
|
136 |
+
loss.backward()
|
137 |
+
opt.step()
|
138 |
+
if is_main_process():
|
139 |
+
pbar.set_description(f' << lr = {lr}: loss = {loss.item()}')
|
140 |
+
|
141 |
+
logits, _ = single_step(zz_x, False, x=proto_x)
|
142 |
+
acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
|
143 |
+
print(f'## *lr = {lr}: acc1 = {acc1}\n')
|
144 |
+
|
145 |
+
if acc1 > max_acc1:
|
146 |
+
max_acc1 = acc1
|
147 |
+
best_lr = lr
|
148 |
+
|
149 |
+
# reset backbone state
|
150 |
+
self.backbone.load_state_dict(self.backbone_state, strict=True)
|
151 |
+
|
152 |
+
print(f'***Best lr = {best_lr} with acc1 = {max_acc1}.\nStart final loop...\n')
|
153 |
+
|
154 |
+
# create optimizer
|
155 |
+
opt = torch.optim.Adam(self.backbone.parameters(),
|
156 |
+
lr=best_lr,
|
157 |
+
betas=(0.9, 0.999),
|
158 |
+
weight_decay=0.)
|
159 |
+
|
160 |
+
# main loop
|
161 |
+
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
|
162 |
+
for i in pbar:
|
163 |
+
opt.zero_grad()
|
164 |
+
z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
|
165 |
+
_, loss = single_step(z, True, x=supp_x, y=supp_y, y_1hot=supp_y_1hot)
|
166 |
+
loss.backward()
|
167 |
+
opt.step()
|
168 |
+
if is_main_process():
|
169 |
+
pbar.set_description(f' >> lr = {best_lr}: loss = {loss.item()}')
|
170 |
+
|
171 |
+
logits, _ = single_step(qry_x, False, x=supp_x, y_1hot=supp_y_1hot) # supp_x has to pair with y_1hot
|
172 |
+
|
173 |
+
return logits
|
174 |
+
|
175 |
+
|
176 |
+
class ProtoNet_Finetune(ProtoNet):
|
177 |
+
def __init__(self, backbone, num_iters=50, lr=5e-2, aug_prob=0.9,
|
178 |
+
aug_types=['color', 'translation']):
|
179 |
+
super().__init__(backbone)
|
180 |
+
self.num_iters = num_iters
|
181 |
+
self.lr = lr
|
182 |
+
self.aug_types = aug_types
|
183 |
+
self.aug_prob = aug_prob
|
184 |
+
|
185 |
+
def load_state_dict(self, state_dict, strict=True):
|
186 |
+
super().load_state_dict(state_dict, strict)
|
187 |
+
|
188 |
+
state_dict = self.backbone.state_dict()
|
189 |
+
self.backbone_state = deepcopy(state_dict)
|
190 |
+
|
191 |
+
def forward(self, supp_x, supp_y, x):
|
192 |
+
"""
|
193 |
+
supp_x.shape = [B, nSupp, C, H, W]
|
194 |
+
supp_y.shape = [B, nSupp]
|
195 |
+
x.shape = [B, nQry, C, H, W]
|
196 |
+
"""
|
197 |
+
# reset backbone state
|
198 |
+
self.backbone.load_state_dict(self.backbone_state, strict=True)
|
199 |
+
|
200 |
+
if self.lr == 0:
|
201 |
+
return super().forward(supp_x, supp_y, x)
|
202 |
+
|
203 |
+
B, nSupp, C, H, W = supp_x.shape
|
204 |
+
num_classes = supp_y.max() + 1 # NOTE: assume B==1
|
205 |
+
device = x.device
|
206 |
+
|
207 |
+
criterion = nn.CrossEntropyLoss()
|
208 |
+
supp_x = supp_x.view(-1, C, H, W)
|
209 |
+
x = x.view(-1, C, H, W)
|
210 |
+
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
|
211 |
+
supp_y = supp_y.view(-1)
|
212 |
+
|
213 |
+
# create optimizer
|
214 |
+
opt = torch.optim.Adam(self.backbone.parameters(),
|
215 |
+
lr=self.lr,
|
216 |
+
betas=(0.9, 0.999),
|
217 |
+
weight_decay=0.)
|
218 |
+
|
219 |
+
def single_step(z, mode=True):
|
220 |
+
'''
|
221 |
+
z = Aug(supp_x) or x
|
222 |
+
'''
|
223 |
+
with torch.set_grad_enabled(mode):
|
224 |
+
# recalculate prototypes from supp_x with updated backbone
|
225 |
+
supp_f = self.backbone.forward(supp_x)
|
226 |
+
supp_f = supp_f.view(B, nSupp, -1)
|
227 |
+
prototypes = torch.bmm(supp_y_1hot.float(), supp_f) # B, nC, d
|
228 |
+
prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
|
229 |
+
|
230 |
+
# compute feature for z
|
231 |
+
feat = self.backbone.forward(z)
|
232 |
+
feat = feat.view(B, z.shape[0], -1) # B, nQry, d
|
233 |
+
|
234 |
+
# classification
|
235 |
+
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
|
236 |
+
loss = None
|
237 |
+
|
238 |
+
if mode: # if enable grad, compute loss
|
239 |
+
loss = criterion(logits.view(B*nSupp, -1), supp_y)
|
240 |
+
|
241 |
+
return logits, loss
|
242 |
+
|
243 |
+
# main loop
|
244 |
+
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
|
245 |
+
for i in pbar:
|
246 |
+
opt.zero_grad()
|
247 |
+
z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
|
248 |
+
_, loss = single_step(z, True)
|
249 |
+
loss.backward()
|
250 |
+
opt.step()
|
251 |
+
if is_main_process():
|
252 |
+
pbar.set_description(f'lr{self.lr}, nSupp{nSupp}, nQry{x.shape[0]}: loss = {loss.item()}')
|
253 |
+
|
254 |
+
logits, _ = single_step(x, False)
|
255 |
+
return logits
|
256 |
+
|
257 |
+
|
258 |
+
class ProtoNet_AdaTok(ProtoNet):
|
259 |
+
def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-2, momentum=0.9, weight_decay=0.):
|
260 |
+
super().__init__(backbone)
|
261 |
+
self.num_adapters = num_adapters
|
262 |
+
self.num_iters = num_iters
|
263 |
+
self.lr = lr
|
264 |
+
self.momentum = momentum
|
265 |
+
self.weight_decay = weight_decay
|
266 |
+
|
267 |
+
def forward(self, supp_x, supp_y, x):
|
268 |
+
"""
|
269 |
+
supp_x.shape = [B, nSupp, C, H, W]
|
270 |
+
supp_y.shape = [B, nSupp]
|
271 |
+
x.shape = [B, nQry, C, H, W]
|
272 |
+
"""
|
273 |
+
B, nSupp, C, H, W = supp_x.shape
|
274 |
+
nQry = x.shape[1]
|
275 |
+
num_classes = supp_y.max() + 1 # NOTE: assume B==1
|
276 |
+
device = x.device
|
277 |
+
|
278 |
+
criterion = nn.CrossEntropyLoss()
|
279 |
+
supp_x = supp_x.view(-1, C, H, W)
|
280 |
+
x = x.view(-1, C, H, W)
|
281 |
+
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
|
282 |
+
supp_y = supp_y.view(-1)
|
283 |
+
|
284 |
+
# prepare adapter tokens
|
285 |
+
ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
|
286 |
+
trunc_normal_(ada_tokens, std=.02)
|
287 |
+
ada_tokens = ada_tokens.detach().requires_grad_()
|
288 |
+
#optimizer = torch.optim.SGD([ada_tokens],
|
289 |
+
optimizer = torch.optim.Adadelta([ada_tokens],
|
290 |
+
lr=self.lr,
|
291 |
+
#momentum=self.momentum,
|
292 |
+
weight_decay=self.weight_decay)
|
293 |
+
|
294 |
+
def single_step(mode=True):
|
295 |
+
with torch.set_grad_enabled(mode):
|
296 |
+
supp_f = self.backbone.forward(supp_x, ada_tokens)
|
297 |
+
supp_f = supp_f.view(B, nSupp, -1)
|
298 |
+
|
299 |
+
# B, nC, nSupp x B, nSupp, d = B, nC, d
|
300 |
+
prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
|
301 |
+
prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
|
302 |
+
|
303 |
+
if mode == False: # no grad
|
304 |
+
feat = self.backbone.forward(x, ada_tokens)
|
305 |
+
feat = feat.view(B, nQry, -1) # B, nQry, d
|
306 |
+
|
307 |
+
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
|
308 |
+
loss = None
|
309 |
+
else:
|
310 |
+
with torch.enable_grad():
|
311 |
+
logits = self.cos_classifier(prototypes, supp_f) # B, nQry, nC
|
312 |
+
loss = criterion(logits.view(B*nSupp, -1), supp_y)
|
313 |
+
|
314 |
+
return logits, loss
|
315 |
+
|
316 |
+
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
|
317 |
+
for i in pbar:
|
318 |
+
optimizer.zero_grad()
|
319 |
+
_, loss = single_step(True)
|
320 |
+
loss.backward()
|
321 |
+
optimizer.step()
|
322 |
+
if is_main_process():
|
323 |
+
pbar.set_description(f'loss = {loss.item()}')
|
324 |
+
|
325 |
+
logits, _ = single_step(False)
|
326 |
+
return logits
|
327 |
+
|
328 |
+
|
329 |
+
class ProtoNet_AdaTok_EntMin(ProtoNet):
|
330 |
+
def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-3, momentum=0.9, weight_decay=0.):
|
331 |
+
super().__init__(backbone)
|
332 |
+
self.num_adapters = num_adapters
|
333 |
+
self.num_iters = num_iters
|
334 |
+
self.lr = lr
|
335 |
+
self.momentum = momentum
|
336 |
+
self.weight_decay = weight_decay
|
337 |
+
|
338 |
+
def forward(self, supp_x, supp_y, x):
|
339 |
+
"""
|
340 |
+
supp_x.shape = [B, nSupp, C, H, W]
|
341 |
+
supp_y.shape = [B, nSupp]
|
342 |
+
x.shape = [B, nQry, C, H, W]
|
343 |
+
"""
|
344 |
+
B, nSupp, C, H, W = supp_x.shape
|
345 |
+
num_classes = supp_y.max() + 1 # NOTE: assume B==1
|
346 |
+
device = x.device
|
347 |
+
|
348 |
+
criterion = entropy_loss
|
349 |
+
supp_x = supp_x.view(-1, C, H, W)
|
350 |
+
x = x.view(-1, C, H, W)
|
351 |
+
supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
|
352 |
+
|
353 |
+
# adapter tokens
|
354 |
+
ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
|
355 |
+
trunc_normal_(ada_tokens, std=.02)
|
356 |
+
ada_tokens = ada_tokens.detach().requires_grad_()
|
357 |
+
optimizer = torch.optim.SGD([ada_tokens],
|
358 |
+
lr=self.lr,
|
359 |
+
momentum=self.momentum,
|
360 |
+
weight_decay=self.weight_decay)
|
361 |
+
|
362 |
+
def single_step(mode=True):
|
363 |
+
with torch.set_grad_enabled(mode):
|
364 |
+
supp_f = self.backbone.forward(supp_x, ada_tokens)
|
365 |
+
supp_f = supp_f.view(B, nSupp, -1)
|
366 |
+
|
367 |
+
# B, nC, nSupp x B, nSupp, d = B, nC, d
|
368 |
+
prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
|
369 |
+
prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
|
370 |
+
|
371 |
+
feat = self.backbone.forward(x, ada_tokens)
|
372 |
+
feat = feat.view(B, x.shape[1], -1) # B, nQry, d
|
373 |
+
|
374 |
+
logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
|
375 |
+
loss = criterion(logits.view(-1, num_classes))
|
376 |
+
|
377 |
+
return logits, loss
|
378 |
+
|
379 |
+
pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
|
380 |
+
for i in pbar:
|
381 |
+
optimizer.zero_grad()
|
382 |
+
_, loss = single_step(True)
|
383 |
+
loss.backward()
|
384 |
+
optimizer.step()
|
385 |
+
if is_main_process():
|
386 |
+
pbar.set_description(f'loss = {loss.item()}')
|
387 |
+
|
388 |
+
logits, _ = single_step(False)
|
389 |
+
return logits
|
models/protonet.py
ADDED
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProtoNet(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        # bias & scale of cosine classifier
        self.bias = nn.Parameter(torch.FloatTensor(1).fill_(0), requires_grad=True)
        self.scale_cls = nn.Parameter(torch.FloatTensor(1).fill_(10), requires_grad=True)

        # backbone
        self.backbone = backbone

    def cos_classifier(self, w, f):
        """
        w.shape = B, nC, d
        f.shape = B, M, d
        """
        f = F.normalize(f, p=2, dim=f.dim()-1, eps=1e-12)
        w = F.normalize(w, p=2, dim=w.dim()-1, eps=1e-12)

        cls_scores = f @ w.transpose(1, 2)  # B, M, nC
        cls_scores = self.scale_cls * (cls_scores + self.bias)
        return cls_scores

    def forward(self, supp_x, supp_y, x):
        """
        supp_x.shape = [B, nSupp, C, H, W]
        supp_y.shape = [B, nSupp]
        x.shape = [B, nQry, C, H, W]
        """
        num_classes = supp_y.max() + 1  # NOTE: assume B==1

        B, nSupp, C, H, W = supp_x.shape
        supp_f = self.backbone.forward(supp_x.view(-1, C, H, W))
        supp_f = supp_f.view(B, nSupp, -1)

        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2)  # B, nC, nSupp

        # B, nC, nSupp x B, nSupp, d = B, nC, d
        prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
        prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True)  # NOTE: may div 0 if some classes got 0 images

        feat = self.backbone.forward(x.view(-1, C, H, W))
        feat = feat.view(B, x.shape[1], -1)  # B, nQry, d

        logits = self.cos_classifier(prototypes, feat)  # B, nQry, nC
        return logits
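A minimal shape-level sketch of ProtoNet above; the toy linear backbone is only a stand-in for the DINO ViT actually used in app.py:

import torch
import torch.nn as nn
from models.protonet import ProtoNet

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 128))  # toy feature extractor, d = 128
net = ProtoNet(backbone)

supp_x = torch.randn(1, 10, 3, 224, 224)                 # one episode: 2-way, 5-shot support set
supp_y = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])  # support labels
qry_x = torch.randn(1, 4, 3, 224, 224)                   # 4 query images

logits = net(supp_x, supp_y, qry_x)
print(logits.shape)  # torch.Size([1, 4, 2]): scaled cosine similarity to each class prototype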
models/resnet_v2.py
ADDED
@@ -0,0 +1,164 @@
1 |
+
# Copyright 2020 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Lint as: python3
|
16 |
+
"""Bottleneck ResNet v2 with GroupNorm and Weight Standardization."""
|
17 |
+
import math
|
18 |
+
|
19 |
+
from os.path import join as pjoin
|
20 |
+
|
21 |
+
from collections import OrderedDict # pylint: disable=g-importing-member
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
|
27 |
+
|
28 |
+
def np2th(weights, conv=False):
|
29 |
+
"""Possibly convert HWIO to OIHW."""
|
30 |
+
if conv:
|
31 |
+
weights = weights.transpose([3, 2, 0, 1])
|
32 |
+
return torch.from_numpy(weights)
|
33 |
+
|
34 |
+
|
35 |
+
class StdConv2d(nn.Conv2d):
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
w = self.weight
|
39 |
+
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
40 |
+
w = (w - m) / torch.sqrt(v + 1e-5)
|
41 |
+
return F.conv2d(x, w, self.bias, self.stride, self.padding,
|
42 |
+
self.dilation, self.groups)
|
43 |
+
|
44 |
+
|
45 |
+
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
|
46 |
+
return StdConv2d(cin, cout, kernel_size=3, stride=stride,
|
47 |
+
padding=1, bias=bias, groups=groups)
|
48 |
+
|
49 |
+
|
50 |
+
def conv1x1(cin, cout, stride=1, bias=False):
|
51 |
+
return StdConv2d(cin, cout, kernel_size=1, stride=stride,
|
52 |
+
padding=0, bias=bias)
|
53 |
+
|
54 |
+
|
55 |
+
class PreActBottleneck(nn.Module):
|
56 |
+
"""Pre-activation (v2) bottleneck block.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self, cin, cout=None, cmid=None, stride=1):
|
60 |
+
super().__init__()
|
61 |
+
cout = cout or cin
|
62 |
+
cmid = cmid or cout//4
|
63 |
+
|
64 |
+
self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
|
65 |
+
self.conv1 = conv1x1(cin, cmid, bias=False)
|
66 |
+
self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
|
67 |
+
self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
|
68 |
+
self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
|
69 |
+
self.conv3 = conv1x1(cmid, cout, bias=False)
|
70 |
+
self.relu = nn.ReLU(inplace=True)
|
71 |
+
|
72 |
+
if (stride != 1 or cin != cout):
|
73 |
+
# Projection also with pre-activation according to paper.
|
74 |
+
self.downsample = conv1x1(cin, cout, stride, bias=False)
|
75 |
+
self.gn_proj = nn.GroupNorm(cout, cout)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
|
79 |
+
# Residual branch
|
80 |
+
residual = x
|
81 |
+
if hasattr(self, 'downsample'):
|
82 |
+
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        # The following will be unreadable if we split lines.
        # pylint: disable=line-too-long
        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        x = self.root(x)
        x = self.body(x)
        return x
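# Usage sketch (reviewer illustration, not part of the commit): instantiating the
# hybrid backbone defined above, assuming the imports at the top of
# models/resnet_v2.py (torch, nn, OrderedDict, StdConv2d). With the R50+ViT-B/16
# setting block_units=(3, 4, 9) and width_factor=1, a 224x224 input comes out as
# a width*16 = 1024 channel, 14x14 feature map.
import torch
from models.resnet_v2 import ResNetV2

backbone = ResNetV2(block_units=(3, 4, 9), width_factor=1)
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)   # expected: torch.Size([1, 1024, 14, 14])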
models/utils.py
ADDED
@@ -0,0 +1,238 @@
import math
import torch
import warnings
import ml_collections
import random
import torch.nn.functional as F


def DiffAugment(x, types=[], prob = 0.5, detach=True):
    """
    x.shape = B, C, H, W
    """
    if random.random() < prob:
        with torch.set_grad_enabled(not detach):
            x = random_hflip(x, prob=0.5)
            for p in types:
                for f in AUGMENT_FNS[p]:
                    x = f(x)
            x = x.contiguous()
    return x


def random_hflip(tensor, prob):
    if prob > random.random():
        return tensor
    return torch.flip(tensor, dims=(3,))

def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x

def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x

def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x

def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x

def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
    w, h = x.size(2), x.size(3)

    imgs = []
    for img in x.unbind(dim = 0):
        max_h = int(w * ratio * ratio_h)
        max_v = int(h * ratio * ratio_v)

        value_h = random.randint(0, max_h) * 2 - max_h
        value_v = random.randint(0, max_v) * 2 - max_v

        if abs(value_h) > 0:
            img = torch.roll(img, value_h, 2)

        if abs(value_v) > 0:
            img = torch.roll(img, value_v, 1)

        imgs.append(img)

    return torch.stack(imgs)

def rand_offset_h(x, ratio=1):
    return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)

def rand_offset_v(x, ratio=1):
    return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)

def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'offset': [rand_offset],
    'offset_h': [rand_offset_h],
    'offset_v': [rand_offset_v],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
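# Usage sketch (reviewer illustration, not part of models/utils.py): DiffAugment
# above randomly flips the whole batch horizontally and then applies every
# augmentation registered under the requested AUGMENT_FNS keys, all gated by
# `prob`; `detach=True` keeps the transforms out of the autograd graph.
# Hypothetical call:
#
#     import torch
#     from models.utils import DiffAugment
#
#     x = torch.rand(4, 3, 224, 224)   # B, C, H, W in [0, 1]
#     x_aug = DiffAugment(x, types=['color', 'translation'], prob=0.9)
#     print(x_aug.shape)               # torch.Size([4, 3, 224, 224])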
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def get_testing():
    """Returns a minimal configuration for testing."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1
    config.transformer.num_heads = 1
    config.transformer.num_layers = 1
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    del config.patches.size
    config.patches.grid = (14, 14)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1
    return config


def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.patches.size = (32, 32)
    return config


def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config


def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.patches.size = (32, 32)
    return config


def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
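# Usage sketch (reviewer illustration, not part of the commit): the helpers above
# return ml_collections.ConfigDict objects; models/vit_google.py consumes
# get_b16_config() to size its ViT-B/16 encoder. A hypothetical check:
from models.utils import get_b16_config

cfg = get_b16_config()
print(cfg.hidden_size, cfg.transformer.num_layers, cfg.patches.size)   # 768 12 (16, 16)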
models/vision_transformer.py
ADDED
@@ -0,0 +1,246 @@
import torch
import torch.nn as nn

import math
from functools import partial
from .utils import trunc_normal_


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
            align_corners=False,
            recompute_scale_factor=False
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x, ada_token=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        if ada_token is not None:
            ada_tokens = ada_token.expand(B, -1, -1)  # B, p, d
            x = torch.cat((x, ada_tokens), dim=1)

        return self.pos_drop(x)

    def forward(self, x, ada_token=None, use_patches=False):
        x = self.prepare_tokens(x, ada_token)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if use_patches:
            return x[:, 1:]
        else:
            return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
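# Usage sketch (reviewer illustration, not part of the commit): the DINO-style ViT
# above returns the [CLS] embedding by default and the patch tokens when
# use_patches=True, so ViT-S/16 maps a 224x224 image to a 384-d vector
# (or 14*14 = 196 patch tokens). Hypothetical check:
import torch
from models.vision_transformer import vit_small

model = vit_small(patch_size=16).eval()
with torch.no_grad():
    img = torch.randn(2, 3, 224, 224)
    cls_feat = model(img)                       # shape (2, 384)
    patch_feat = model(img, use_patches=True)   # shape (2, 196, 384)
print(cls_feat.shape, patch_feat.shape)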
models/vit_google.py
ADDED
@@ -0,0 +1,372 @@
import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from .utils import get_b16_config
from .resnet_v2 import ResNetV2


CONFIGS = {
    'ViT-B_16': get_b16_config(),
    #'ViT-B_32': get_b32_config(),
    #'ViT-L_16': get_l16_config(),
    #'ViT-L_32': get_l32_config(),
    #'ViT-H_14': get_h14_config(),
    #'R50-ViT-B_16': get_r50_b16_config(),
    #'testing': configs.get_testing(),
}

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_size = patch_size
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def interpolate_pos_encoding(self, x, h, w):
        npatch = x.shape[1] - 1
        N = self.position_embeddings.shape[1] - 1
        if npatch == N and w == h:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size[0]
        h0 = h // self.patch_size[1]
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode='bicubic',
            align_corners=False,
            recompute_scale_factor=False
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, x):
        B, nc, h, w = x.shape
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)

        # Linear embedding
        x = self.patch_embeddings(x)

        # add the [CLS] token to the embed patch tokens
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        embeddings = x + self.interpolate_pos_encoding(x, h, w)
        embeddings = self.dropout(embeddings)
        return embeddings


class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, vis=False):
        super(VisionTransformer, self).__init__()
        #self.num_classes = num_classes
        #self.classifier = config.classifier
        self.embed_dim = config.hidden_size

        self.transformer = Transformer(config, img_size, vis)
        #self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None, use_patches=False):
        x, attn_weights = self.transformer(x)
        #logits = self.head(x[:, 0])

        #if labels is not None:
        #    loss_fct = CrossEntropyLoss()
        #    loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        #    return loss
        #else:
        #    return logits, attn_weights

        if use_patches:
            return x[:, 1:]
        else:
            return x[:, 0]

    def load_from(self, weights):
        with torch.no_grad():
            #if self.zero_head:
            #    nn.init.zeros_(self.head.weight)
            #    nn.init.zeros_(self.head.bias)
            #else:
            #    self.head.weight.copy_(np2th(weights["head/kernel"]).t())
            #    self.head.bias.copy_(np2th(weights["head/bias"]).t())

            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)
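# Usage sketch (reviewer illustration, not part of the commit): the Google/JAX-style
# ViT above is built from a CONFIGS entry and, like the DINO model, returns the
# [CLS] embedding (hidden_size=768 for ViT-B_16); load_from() would additionally
# copy converted .npz weights into the module. Hypothetical check:
import torch
from models.vit_google import CONFIGS, VisionTransformer

model = VisionTransformer(CONFIGS['ViT-B_16'], img_size=224, vis=False).eval()
with torch.no_grad():
    feat = model(torch.randn(1, 3, 224, 224))
print(feat.shape)   # torch.Size([1, 768])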
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch==1.13.1
torchvision
numpy<2.0
tqdm
matplotlib
dotmap
Pillow
timm
ml-collections
ftfy
tensorboard
Google-Images-Search
semantic-version
pytz