vincentb25 committed
Commit
2908104
·
1 Parent(s): 5199d23

Added all of the code for new space

app.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import numpy as np
3
+ import time
4
+ import random
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+
10
+ from models import get_model
11
+ from dotmap import DotMap
12
+ from PIL import Image
13
+
14
+ #os.environ['TERM'] = 'linux'
15
+ #os.environ['TERMINFO'] = '/etc/terminfo'
16
+
17
+ # args
18
+ args = DotMap()
19
+ args.deploy = 'vanilla'
20
+ args.arch = 'dino_small_patch16'
21
+ args.no_pretrain = True
22
+ args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
23
+ args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
24
+ args.cx = '06d75168141bc47f1'
25
+
26
+
27
+ # model
28
+ device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ model = get_model(args)
30
+ model.to(device)
31
+ checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
32
+ model.load_state_dict(checkpoint['model'], strict=True)
33
+
34
+
35
+ # image transforms
36
+ def test_transform():
37
+ def _convert_image_to_rgb(im):
38
+ return im.convert('RGB')
39
+
40
+ return transforms.Compose([
41
+ transforms.Resize(256),
42
+ transforms.CenterCrop(224),
43
+ _convert_image_to_rgb,
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
46
+ std=[0.229, 0.224, 0.225]),
47
+ ])
48
+
49
+ preprocess = test_transform()
50
+
51
+ @torch.no_grad()
52
+ def denormalize(x, mean, std):
53
+ # 3, H, W
54
+ t = x.clone()
55
+ t.mul_(std).add_(mean)
56
+ return torch.clamp(t, 0, 1)
57
+
58
+
59
+ # Gradio UI
60
+ def inference(query, class1_name="class1", support_imgs=None, class2_name="class2", support_imgs2=None):
61
+ '''
62
+ query: PIL image
63
+ class1_name/class2_name: class names; support_imgs/support_imgs2: uploaded support image files for each class
64
+ '''
65
+
66
+
67
+ #first, open the images
68
+ support_imgs = [Image.open(img) for img in support_imgs]
69
+ support_imgs2 = [Image.open(img) for img in support_imgs2]
70
+
71
+
72
+ labels = [class1_name, class2_name]
73
+
74
+ supp_x = []
75
+ supp_y = []
76
+
77
+ for i, (class_name, support_img) in enumerate(zip([class1_name, class2_name], [support_imgs, support_imgs2])):
78
+ for img in support_img:
79
+ x_im = preprocess(img)
80
+ supp_x.append(x_im)
81
+ supp_y.append(i)
82
+
83
+
84
+ supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)
85
+ supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)
86
+
87
+ query = preprocess(query).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, 3, H, W)
88
+
89
+ print(f"Shape of supp_x: {supp_x.shape}")
90
+ print(f"Shape of supp_y: {supp_y.shape}")
91
+ print(f"Shape of query: {query.shape}")
92
+
93
+
94
+
95
+ with torch.cuda.amp.autocast(True):
96
+ output = model(supp_x, supp_y, query) # (1, 1, n_labels)
97
+
98
+ probs = output.softmax(dim=-1).detach().cpu().numpy()
99
+
100
+ return {k: float(v) for k, v in zip(labels, probs[0, 0])}
101
+
102
+
103
+ # DEBUG
104
+ ##query = Image.open('../labrador-puppy.jpg')
105
+ #query = Image.open('/Users/hushell/Documents/Dan_tr.png')
106
+ ##labels = 'dog, cat'
107
+ #labels = 'girl, sussie'
108
+ #output = inference(query, class1_name='dog', support_imgs=['dog1.jpg'], class2_name='cat', support_imgs2=['cat1.jpg'])  # updated to match the new signature; file names are placeholders
109
+ #print(output)
110
+
111
+
112
+ title = "P>M>F few-shot learning pipeline"
113
+ description = "Short description: We take a ViT-small backbone, which is pre-trained with DINO and meta-trained on Meta-Dataset; for few-shot classification, we use a ProtoNet classifier. In this space the support set is built from the images you upload for each class; the original demo built it by searching images from Google (GIS), which can be unstable since search requests may fail for many reasons (e.g., the daily request limit is reached). This code is heavily inspired by the original HF space <a href='https://huggingface.co/spaces/hushell/pmf_with_gis' target='_blank'>here</a>"
114
+ article = "<p style='text-align: center'><a href='http://arxiv.org/abs/2204.07305' target='_blank'>Arxiv</a></p>"
115
+
116
+
117
+ gr.Interface(fn=inference,
118
+ inputs=[
119
+ gr.Image(label="Image to classify", type="pil"),
120
+ #gr.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
121
+
122
+ gr.Textbox(lines=1, label="First class name:", placeholder="Enter first class name"),
123
+ gr.File(label="Drag or select one or more photos of the first class", file_types=["image"], file_count="multiple"),
124
+
125
+ gr.Textbox(lines=1, label="Second class name:", placeholder="Enter second class name"),
126
+ gr.File(label="Drag or select one or more photos of the second class", file_types=["image"], file_count="multiple"),
127
+ ],
128
+ theme="grass",
129
+ outputs=[
130
+ gr.Label(label="Predicted class probabilities"),
131
+ #gr.Image(type='pil', label="Support examples from Google image search"),
132
+ ],
133
+ title=title,
134
+ description=description,
135
+ article=article,
136
+ ).launch(debug=True)
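
For reference, the call model(supp_x, supp_y, query) above performs a ProtoNet step: embed the support and query images with the backbone, average the support embeddings per class into prototypes, and score the query by similarity to each prototype. The sketch below illustrates that idea on plain feature tensors; it is a minimal illustration with an assumed cosine metric, not the repo's ProtoNet class.

    import torch
    import torch.nn.functional as F

    def protonet_logits(supp_feat, supp_y, query_feat, n_classes):
        # supp_feat: (n_support, d), supp_y: (n_support,), query_feat: (n_query, d)
        prototypes = torch.stack(
            [supp_feat[supp_y == c].mean(dim=0) for c in range(n_classes)]
        )                                             # (n_classes, d)
        query_feat = F.normalize(query_feat, dim=-1)  # assumed cosine metric
        prototypes = F.normalize(prototypes, dim=-1)
        return query_feat @ prototypes.t()            # (n_query, n_classes)

    # toy check: 3 support images per class, feature dim 384 (ViT-small)
    feats = torch.randn(6, 384)
    labels = torch.tensor([0, 0, 0, 1, 1, 1])
    query = torch.randn(1, 384)
    print(protonet_logits(feats, labels, query, n_classes=2).softmax(dim=-1))
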
models/__init__.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ #from timm.models import create_model
5
+ from .protonet import ProtoNet
6
+ from .deploy import ProtoNet_Finetune, ProtoNet_Auto_Finetune, ProtoNet_AdaTok, ProtoNet_AdaTok_EntMin
7
+
8
+
9
+ def get_backbone(args):
10
+ if args.arch == 'vit_base_patch16_224_in21k':
11
+ from .vit_google import VisionTransformer, CONFIGS
12
+
13
+ config = CONFIGS['ViT-B_16']
14
+ model = VisionTransformer(config, 224)
15
+
16
+ url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'
17
+ pretrained_weights = 'pretrained_ckpts/vit_base_patch16_224_in21k.npz'
18
+
19
+ if not os.path.exists(pretrained_weights):
20
+ try:
21
+ import wget
22
+ os.makedirs('pretrained_ckpts', exist_ok=True)
23
+ wget.download(url, pretrained_weights)
24
+ except:
25
+ print(f'Cannot download pretrained weights from {url}. Check if `pip install wget` works.')
26
+
27
+ model.load_from(np.load(pretrained_weights))
28
+ print('Pretrained weights found at {}'.format(pretrained_weights))
29
+
30
+ elif args.arch == 'dino_base_patch16':
31
+ from . import vision_transformer as vit
32
+
33
+ model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
34
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
35
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
36
+
37
+ model.load_state_dict(state_dict, strict=True)
38
+ print('Pretrained weights found at {}'.format(url))
39
+
40
+ elif args.arch == 'deit_base_patch16':
41
+ from . import vision_transformer as vit
42
+
43
+ model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
44
+ url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"
45
+ state_dict = torch.hub.load_state_dict_from_url(url=url)["model"]
46
+
47
+ for k in ['head.weight', 'head.bias']:
48
+ if k in state_dict:
49
+ print(f"removing key {k} from pretrained checkpoint")
50
+ del state_dict[k]
51
+
52
+ model.load_state_dict(state_dict, strict=True)
53
+ print('Pretrained weights found at {}'.format(url))
54
+
55
+ elif args.arch == 'deit_small_patch16':
56
+ from . import vision_transformer as vit
57
+
58
+ model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
59
+ url = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
60
+ state_dict = torch.hub.load_state_dict_from_url(url=url)["model"]
61
+
62
+ for k in ['head.weight', 'head.bias']:
63
+ if k in state_dict:
64
+ print(f"removing key {k} from pretrained checkpoint")
65
+ del state_dict[k]
66
+
67
+ model.load_state_dict(state_dict, strict=True)
68
+ print('Pretrained weights found at {}'.format(url))
69
+
70
+ elif args.arch == 'dino_small_patch16':
71
+ from . import vision_transformer as vit
72
+
73
+ model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
74
+
75
+ if not args.no_pretrain:
76
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
77
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
78
+
79
+ model.load_state_dict(state_dict, strict=True)
80
+ print('Pretrained weights found at {}'.format(url))
81
+
82
+ elif args.arch == 'beit_base_patch16_224_pt22k':
83
+ from .beit import default_pretrained_model
84
+ model = default_pretrained_model(args)
85
+ print('Pretrained BEiT loaded')
86
+
87
+ elif args.arch == 'clip_base_patch16_224':
88
+ from . import clip
89
+ model, _ = clip.load('ViT-B/16', 'cpu')
90
+
91
+ elif args.arch == 'clip_resnet50':
92
+ from . import clip
93
+ model, _ = clip.load('RN50', 'cpu')
94
+
95
+ elif args.arch == 'dino_resnet50':
96
+ from torchvision.models.resnet import resnet50
97
+
98
+ model = resnet50(pretrained=False)
99
+ model.fc = torch.nn.Identity()
100
+
101
+ if not args.no_pretrain:
102
+ state_dict = torch.hub.load_state_dict_from_url(
103
+ url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
104
+ map_location="cpu",
105
+ )
106
+ model.load_state_dict(state_dict, strict=False)
107
+
108
+ elif args.arch == 'resnet50':
109
+ from torchvision.models.resnet import resnet50
110
+
111
+ pretrained = not args.no_pretrain
112
+ model = resnet50(pretrained=pretrained)
113
+ model.fc = torch.nn.Identity()
114
+
115
+ elif args.arch == 'resnet18':
116
+ from torchvision.models.resnet import resnet18
117
+
118
+ pretrained = not args.no_pretrain
119
+ model = resnet18(pretrained=pretrained)
120
+ model.fc = torch.nn.Identity()
121
+
122
+ elif args.arch == 'dino_xcit_medium_24_p16':
123
+ model = torch.hub.load('facebookresearch/xcit:main', 'xcit_medium_24_p16')
124
+ model.head = torch.nn.Identity()
125
+ state_dict = torch.hub.load_state_dict_from_url(
126
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
127
+ map_location="cpu",
128
+ )
129
+ model.load_state_dict(state_dict, strict=False)
130
+
131
+ elif args.arch == 'dino_xcit_medium_24_p8':
132
+ model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
133
+
134
+ elif args.arch == 'simclrv2_resnet50':
135
+ import sys
136
+ sys.path.insert(
137
+ 0,
138
+ 'cog',
139
+ )
140
+ import model_utils
141
+
142
+ model_utils.MODELS_ROOT_DIR = 'cog/models'
143
+ ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts/simclrv2_resnet50.pth')
144
+ resnet, _ = model_utils.load_pretrained_backbone(args.arch, ckpt_file)
145
+
146
+ class Wrapper(torch.nn.Module):
147
+ def __init__(self, model):
148
+ super(Wrapper, self).__init__()
149
+ self.model = model
150
+
151
+ def forward(self, x):
152
+ return self.model(x, apply_fc=False)
153
+
154
+ model = Wrapper(resnet)
155
+
156
+ elif args.arch in ['mocov2_resnet50', 'swav_resnet50', 'barlow_resnet50']:
157
+ from torchvision.models.resnet import resnet50
158
+
159
+ model = resnet50(pretrained=False)
160
+ ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts_converted/{}.pth'.format(args.arch))
161
+ ckpt = torch.load(ckpt_file)
162
+
163
+ msg = model.load_state_dict(ckpt, strict=False)
164
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
165
+
166
+ # remove the fully-connected layer
167
+ model.fc = torch.nn.Identity()
168
+
169
+ else:
170
+ raise ValueError(f'{args.arch} is not considered in the current code.')
171
+
172
+ return model
173
+
174
+
175
+ def get_model(args):
176
+ backbone = get_backbone(args)
177
+
178
+ if args.deploy == 'vanilla':
179
+ model = ProtoNet(backbone)
180
+ elif args.deploy == 'finetune':
181
+ model = ProtoNet_Finetune(backbone, args.ada_steps, args.ada_lr, args.aug_prob, args.aug_types)
182
+ elif args.deploy == 'finetune_autolr':
183
+ model = ProtoNet_Auto_Finetune(backbone, args.ada_steps, args.aug_prob, args.aug_types)
184
+ elif args.deploy == 'ada_tokens':
185
+ model = ProtoNet_AdaTok(backbone, args.num_adapters,
186
+ args.ada_steps, args.ada_lr)
187
+ elif args.deploy == 'ada_tokens_entmin':
188
+ model = ProtoNet_AdaTok_EntMin(backbone, args.num_adapters,
189
+ args.ada_steps, args.ada_lr)
190
+ else:
191
+ raise ValueError(f'deploy method {args.deploy} is not supported.')
192
+ return model
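
As a usage note, app.py above only exercises the dino_small_patch16 / vanilla branch of this factory. A minimal sketch of that path, with the checkpoint URL taken from app.py:

    import torch
    from dotmap import DotMap
    from models import get_model

    args = DotMap()
    args.deploy = 'vanilla'            # plain ProtoNet on top of the backbone
    args.arch = 'dino_small_patch16'   # ViT-small backbone
    args.no_pretrain = True            # skip the DINO download; the P>M>F checkpoint is loaded instead

    model = get_model(args)
    ckpt = torch.hub.load_state_dict_from_url(
        'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth',
        map_location='cpu')
    model.load_state_dict(ckpt['model'], strict=True)
    model.eval()
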
models/beit.py ADDED
@@ -0,0 +1,598 @@
1
+ # --------------------------------------------------------
2
+ # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # By Hangbo Bao
7
+ # Based on timm and DeiT code bases
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # https://github.com/facebookresearch/deit/
10
+ # https://github.com/facebookresearch/dino
11
+ # --------------------------------------------------------'
12
+ import math
13
+ from functools import partial
14
+ from scipy import interpolate
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
21
+ #from timm.models.registry import register_model
22
+
23
+
24
+ def _cfg(url='', **kwargs):
25
+ return {
26
+ 'url': url,
27
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
28
+ 'crop_pct': .9, 'interpolation': 'bicubic',
29
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
30
+ **kwargs
31
+ }
32
+
33
+
34
+ class DropPath(nn.Module):
35
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
36
+ """
37
+ def __init__(self, drop_prob=None):
38
+ super(DropPath, self).__init__()
39
+ self.drop_prob = drop_prob
40
+
41
+ def forward(self, x):
42
+ return drop_path(x, self.drop_prob, self.training)
43
+
44
+ def extra_repr(self) -> str:
45
+ return 'p={}'.format(self.drop_prob)
46
+
47
+
48
+ class Mlp(nn.Module):
49
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50
+ super().__init__()
51
+ out_features = out_features or in_features
52
+ hidden_features = hidden_features or in_features
53
+ self.fc1 = nn.Linear(in_features, hidden_features)
54
+ self.act = act_layer()
55
+ self.fc2 = nn.Linear(hidden_features, out_features)
56
+ self.drop = nn.Dropout(drop)
57
+
58
+ def forward(self, x):
59
+ x = self.fc1(x)
60
+ x = self.act(x)
61
+ # x = self.drop(x)
62
+ # dropout is commented out here, following the original BERT implementation
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ class Attention(nn.Module):
69
+ def __init__(
70
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
71
+ proj_drop=0., window_size=None, attn_head_dim=None):
72
+ super().__init__()
73
+ self.num_heads = num_heads
74
+ head_dim = dim // num_heads
75
+ if attn_head_dim is not None:
76
+ head_dim = attn_head_dim
77
+ all_head_dim = head_dim * self.num_heads
78
+ self.scale = qk_scale or head_dim ** -0.5
79
+
80
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
81
+ if qkv_bias:
82
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
83
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
84
+ else:
85
+ self.q_bias = None
86
+ self.v_bias = None
87
+
88
+ if window_size:
89
+ self.window_size = window_size
90
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
91
+ self.relative_position_bias_table = nn.Parameter(
92
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
93
+ # cls to token & token 2 cls & cls to cls
94
+
95
+ # get pair-wise relative position index for each token inside the window
96
+ coords_h = torch.arange(window_size[0])
97
+ coords_w = torch.arange(window_size[1])
98
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
100
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
101
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
102
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
103
+ relative_coords[:, :, 1] += window_size[1] - 1
104
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
105
+ relative_position_index = \
106
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
107
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
108
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
109
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
110
+ relative_position_index[0, 0] = self.num_relative_distance - 1
111
+
112
+ self.register_buffer("relative_position_index", relative_position_index)
113
+ else:
114
+ self.window_size = None
115
+ self.relative_position_bias_table = None
116
+ self.relative_position_index = None
117
+
118
+ self.attn_drop = nn.Dropout(attn_drop)
119
+ self.proj = nn.Linear(all_head_dim, dim)
120
+ self.proj_drop = nn.Dropout(proj_drop)
121
+
122
+ def forward(self, x, rel_pos_bias=None):
123
+ B, N, C = x.shape
124
+ qkv_bias = None
125
+ if self.q_bias is not None:
126
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
127
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
128
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
129
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
130
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
131
+
132
+ q = q * self.scale
133
+ attn = (q @ k.transpose(-2, -1))
134
+
135
+ if self.relative_position_bias_table is not None:
136
+ relative_position_bias = \
137
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
138
+ self.window_size[0] * self.window_size[1] + 1,
139
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
140
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
141
+ attn = attn + relative_position_bias.unsqueeze(0)
142
+
143
+ if rel_pos_bias is not None:
144
+ attn = attn + rel_pos_bias
145
+
146
+ attn = attn.softmax(dim=-1)
147
+ attn = self.attn_drop(attn)
148
+
149
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
150
+ x = self.proj(x)
151
+ x = self.proj_drop(x)
152
+ return x
153
+
154
+
155
+ class Block(nn.Module):
156
+
157
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
158
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
159
+ window_size=None, attn_head_dim=None):
160
+ super().__init__()
161
+ self.norm1 = norm_layer(dim)
162
+ self.attn = Attention(
163
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
164
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
165
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
166
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
167
+ self.norm2 = norm_layer(dim)
168
+ mlp_hidden_dim = int(dim * mlp_ratio)
169
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
170
+
171
+ if init_values > 0:
172
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
173
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
174
+ else:
175
+ self.gamma_1, self.gamma_2 = None, None
176
+
177
+ def forward(self, x, rel_pos_bias=None):
178
+ if self.gamma_1 is None:
179
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
180
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
181
+ else:
182
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
183
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
184
+ return x
185
+
186
+
187
+ class PatchEmbed(nn.Module):
188
+ """ Image to Patch Embedding
189
+ """
190
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
191
+ super().__init__()
192
+ img_size = to_2tuple(img_size)
193
+ patch_size = to_2tuple(patch_size)
194
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
195
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
196
+ self.img_size = img_size
197
+ self.patch_size = patch_size
198
+ self.num_patches = num_patches
199
+
200
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
201
+
202
+ def forward(self, x, **kwargs):
203
+ B, C, H, W = x.shape
204
+ # FIXME look at relaxing size constraints
205
+ assert H == self.img_size[0] and W == self.img_size[1], \
206
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
207
+ x = self.proj(x).flatten(2).transpose(1, 2)
208
+ return x
209
+
210
+
211
+ class RelativePositionBias(nn.Module):
212
+
213
+ def __init__(self, window_size, num_heads):
214
+ super().__init__()
215
+ self.window_size = window_size
216
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
217
+ self.relative_position_bias_table = nn.Parameter(
218
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
219
+ # cls to token & token 2 cls & cls to cls
220
+
221
+ # get pair-wise relative position index for each token inside the window
222
+ coords_h = torch.arange(window_size[0])
223
+ coords_w = torch.arange(window_size[1])
224
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
225
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
226
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
227
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
228
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
229
+ relative_coords[:, :, 1] += window_size[1] - 1
230
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
231
+ relative_position_index = \
232
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
233
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
234
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
235
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
236
+ relative_position_index[0, 0] = self.num_relative_distance - 1
237
+
238
+ self.register_buffer("relative_position_index", relative_position_index)
239
+
240
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
241
+
242
+ def forward(self):
243
+ relative_position_bias = \
244
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
245
+ self.window_size[0] * self.window_size[1] + 1,
246
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
247
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
248
+
249
+
250
+ class VisionTransformer(nn.Module):
251
+ """ Vision Transformer with support for patch or hybrid CNN input stage
252
+ """
253
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
254
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
255
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
256
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
257
+ use_mean_pooling=True, init_scale=0.001):
258
+ super().__init__()
259
+ self.num_classes = num_classes
260
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
261
+
262
+ self.patch_embed = PatchEmbed(
263
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
264
+ num_patches = self.patch_embed.num_patches
265
+
266
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
267
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
268
+ if use_abs_pos_emb:
269
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
270
+ else:
271
+ self.pos_embed = None
272
+ self.pos_drop = nn.Dropout(p=drop_rate)
273
+
274
+ if use_shared_rel_pos_bias:
275
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
276
+ else:
277
+ self.rel_pos_bias = None
278
+
279
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
280
+ self.use_rel_pos_bias = use_rel_pos_bias
281
+ self.blocks = nn.ModuleList([
282
+ Block(
283
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
284
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
285
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
286
+ for i in range(depth)])
287
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
288
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
289
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
290
+
291
+ if self.pos_embed is not None:
292
+ trunc_normal_(self.pos_embed, std=.02)
293
+ trunc_normal_(self.cls_token, std=.02)
294
+ # trunc_normal_(self.mask_token, std=.02)
295
+ self.apply(self._init_weights)
296
+ self.fix_init_weight()
297
+
298
+ if num_classes > 0:
299
+ trunc_normal_(self.head.weight, std=.02)
300
+ self.head.weight.data.mul_(init_scale)
301
+ self.head.bias.data.mul_(init_scale)
302
+
303
+ def fix_init_weight(self):
304
+ def rescale(param, layer_id):
305
+ param.div_(math.sqrt(2.0 * layer_id))
306
+
307
+ for layer_id, layer in enumerate(self.blocks):
308
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
309
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
310
+
311
+ def _init_weights(self, m):
312
+ if isinstance(m, nn.Linear):
313
+ trunc_normal_(m.weight, std=.02)
314
+ if isinstance(m, nn.Linear) and m.bias is not None:
315
+ nn.init.constant_(m.bias, 0)
316
+ elif isinstance(m, nn.LayerNorm):
317
+ nn.init.constant_(m.bias, 0)
318
+ nn.init.constant_(m.weight, 1.0)
319
+
320
+ def get_num_layers(self):
321
+ return len(self.blocks)
322
+
323
+ @torch.jit.ignore
324
+ def no_weight_decay(self):
325
+ return {'pos_embed', 'cls_token'}
326
+
327
+ def get_classifier(self):
328
+ return self.head
329
+
330
+ def reset_classifier(self, num_classes, global_pool=''):
331
+ self.num_classes = num_classes
332
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
333
+
334
+ def forward_features(self, x):
335
+ x = self.patch_embed(x)
336
+ batch_size, seq_len, _ = x.size()
337
+
338
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
339
+ x = torch.cat((cls_tokens, x), dim=1)
340
+ if self.pos_embed is not None:
341
+ x = x + self.pos_embed
342
+ x = self.pos_drop(x)
343
+
344
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
345
+ for blk in self.blocks:
346
+ x = blk(x, rel_pos_bias=rel_pos_bias)
347
+
348
+ x = self.norm(x)
349
+ if self.fc_norm is not None:
350
+ t = x[:, 1:, :]
351
+ return self.fc_norm(t.mean(1))
352
+ else:
353
+ return x[:, 0]
354
+
355
+ def forward(self, x):
356
+ x = self.forward_features(x)
357
+ x = self.head(x)
358
+ return x
359
+
360
+
361
+ #@register_model
362
+ def beit_base_patch16_224(pretrained=False, **kwargs):
363
+ model = VisionTransformer(
364
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
365
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
366
+ model.default_cfg = _cfg()
367
+ return model
368
+
369
+
370
+ #@register_model
371
+ def beit_base_patch16_384(pretrained=False, **kwargs):
372
+ model = VisionTransformer(
373
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
374
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
375
+ model.default_cfg = _cfg()
376
+ return model
377
+
378
+
379
+ #@register_model
380
+ def beit_large_patch16_224(pretrained=False, **kwargs):
381
+ model = VisionTransformer(
382
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
383
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
384
+ model.default_cfg = _cfg()
385
+ return model
386
+
387
+
388
+ #@register_model
389
+ def beit_large_patch16_384(pretrained=False, **kwargs):
390
+ model = VisionTransformer(
391
+ img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
392
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
393
+ model.default_cfg = _cfg()
394
+ return model
395
+
396
+
397
+ #@register_model
398
+ def beit_large_patch16_512(pretrained=False, **kwargs):
399
+ model = VisionTransformer(
400
+ img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
401
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
402
+ model.default_cfg = _cfg()
403
+ return model
404
+
405
+
406
+ def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
407
+ missing_keys = []
408
+ unexpected_keys = []
409
+ error_msgs = []
410
+ # copy state_dict so _load_from_state_dict can modify it
411
+ metadata = getattr(state_dict, '_metadata', None)
412
+ state_dict = state_dict.copy()
413
+ if metadata is not None:
414
+ state_dict._metadata = metadata
415
+
416
+ def _load(module, prefix=''):
417
+ local_metadata = {} if metadata is None else metadata.get(
418
+ prefix[:-1], {})
419
+ module._load_from_state_dict(
420
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
421
+ for name, child in module._modules.items():
422
+ if child is not None:
423
+ _load(child, prefix + name + '.')
424
+
425
+ _load(model, prefix=prefix)
426
+
427
+ warn_missing_keys = []
428
+ ignore_missing_keys = []
429
+ for key in missing_keys:
430
+ keep_flag = True
431
+ for ignore_key in ignore_missing.split('|'):
432
+ if ignore_key in key:
433
+ keep_flag = False
434
+ break
435
+ if keep_flag:
436
+ warn_missing_keys.append(key)
437
+ else:
438
+ ignore_missing_keys.append(key)
439
+
440
+ missing_keys = warn_missing_keys
441
+
442
+ if len(missing_keys) > 0:
443
+ print("Weights of {} not initialized from pretrained model: {}".format(
444
+ model.__class__.__name__, missing_keys))
445
+ if len(unexpected_keys) > 0:
446
+ print("Weights from pretrained model not used in {}: {}".format(
447
+ model.__class__.__name__, unexpected_keys))
448
+ if len(ignore_missing_keys) > 0:
449
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
450
+ model.__class__.__name__, ignore_missing_keys))
451
+ if len(error_msgs) > 0:
452
+ print('\n'.join(error_msgs))
453
+
454
+
455
+ def default_pretrained_model(args):
456
+ model = beit_base_patch16_224(
457
+ pretrained=False,
458
+ img_size=args.image_size,
459
+ num_classes=0,
460
+ drop_rate=0.,
461
+ drop_path_rate=0.1,
462
+ attn_drop_rate=0.,
463
+ #drop_block_rate=None,
464
+ use_mean_pooling=True,
465
+ init_scale=0.001,
466
+ use_rel_pos_bias=True,
467
+ use_abs_pos_emb=False,
468
+ init_values=0.1,
469
+ )
470
+
471
+ #url = 'https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k.pth'
472
+ url = 'https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth'
473
+
474
+ checkpoint = torch.hub.load_state_dict_from_url(
475
+ url, map_location='cpu', check_hash=True)
476
+ print('Pretrained weights found at {}'.format(url))
477
+
478
+ # select key
479
+ checkpoint_model = None
480
+ for model_key in ['model', 'module']:
481
+ if model_key in checkpoint:
482
+ checkpoint_model = checkpoint[model_key]
483
+ print("Load state_dict by model_key = %s" % model_key)
484
+ break
485
+ if checkpoint_model is None:
486
+ checkpoint_model = checkpoint
487
+
488
+ # remove head
489
+ state_dict = model.state_dict()
490
+ for k in ['head.weight', 'head.bias']:
491
+ #if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
492
+ if k in checkpoint_model:
493
+ print(f"Removing key {k} from pretrained checkpoint")
494
+ del checkpoint_model[k]
495
+
496
+ # resize rel_pos_bias
497
+ if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
498
+ print("Expand the shared relative position embedding to each transformer block. ")
499
+ num_layers = model.get_num_layers()
500
+ rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"]
501
+ for i in range(num_layers):
502
+ checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
503
+
504
+ checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")
505
+
506
+ all_keys = list(checkpoint_model.keys())
507
+ for key in all_keys:
508
+ if "relative_position_index" in key:
509
+ checkpoint_model.pop(key)
510
+
511
+ if "relative_position_bias_table" in key:
512
+ rel_pos_bias = checkpoint_model[key]
513
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
514
+ dst_num_pos, _ = model.state_dict()[key].size()
515
+ dst_patch_shape = model.patch_embed.patch_shape
516
+ if dst_patch_shape[0] != dst_patch_shape[1]:
517
+ raise NotImplementedError()
518
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
519
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
520
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
521
+ if src_size != dst_size:
522
+ print("Position interpolate for %s from %dx%d to %dx%d" % (
523
+ key, src_size, src_size, dst_size, dst_size))
524
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
525
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
526
+
527
+ def geometric_progression(a, r, n):
528
+ return a * (1.0 - r ** n) / (1.0 - r)
529
+
530
+ left, right = 1.01, 1.5
531
+ while right - left > 1e-6:
532
+ q = (left + right) / 2.0
533
+ gp = geometric_progression(1, q, src_size // 2)
534
+ if gp > dst_size // 2:
535
+ right = q
536
+ else:
537
+ left = q
538
+
539
+ # if q > 1.090307:
540
+ # q = 1.090307
541
+
542
+ dis = []
543
+ cur = 1
544
+ for i in range(src_size // 2):
545
+ dis.append(cur)
546
+ cur += q ** (i + 1)
547
+
548
+ r_ids = [-_ for _ in reversed(dis)]
549
+
550
+ x = r_ids + [0] + dis
551
+ y = r_ids + [0] + dis
552
+
553
+ t = dst_size // 2.0
554
+ dx = np.arange(-t, t + 0.1, 1.0)
555
+ dy = np.arange(-t, t + 0.1, 1.0)
556
+
557
+ print("Original positions = %s" % str(x))
558
+ print("Target positions = %s" % str(dx))
559
+
560
+ all_rel_pos_bias = []
561
+
562
+ for i in range(num_attn_heads):
563
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
564
+ f = interpolate.interp2d(x, y, z, kind='cubic')
565
+ all_rel_pos_bias.append(
566
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
567
+
568
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
569
+
570
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
571
+ checkpoint_model[key] = new_rel_pos_bias
572
+
573
+ # interpolate position embedding
574
+ if 'pos_embed' in checkpoint_model:
575
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
576
+ embedding_size = pos_embed_checkpoint.shape[-1]
577
+ num_patches = model.patch_embed.num_patches
578
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
579
+ # height (== width) for the checkpoint position embedding
580
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
581
+ # height (== width) for the new position embedding
582
+ new_size = int(num_patches ** 0.5)
583
+ # class_token and dist_token are kept unchanged
584
+ if orig_size != new_size:
585
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
586
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
587
+ # only the position tokens are interpolated
588
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
589
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
590
+ pos_tokens = torch.nn.functional.interpolate(
591
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
592
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
593
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
594
+ checkpoint_model['pos_embed'] = new_pos_embed
595
+
596
+ load_state_dict(model, checkpoint_model)
597
+ return model
598
+
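
One caveat when using this file through get_backbone(args): default_pretrained_model reads args.image_size, which the DotMap args in app.py never set. A hedged sketch of what selecting the BEiT branch would look like; the value 224 matches the CenterCrop in app.py and is an assumption, not code from the repo:

    from dotmap import DotMap
    from models import get_model

    args = DotMap()
    args.deploy = 'vanilla'
    args.arch = 'beit_base_patch16_224_pt22k'
    args.image_size = 224    # required by default_pretrained_model(args); assumed value
    model = get_model(args)  # downloads the BEiT-22k checkpoint and adapts its position embeddings
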
models/clip/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
models/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
models/clip/clip.py ADDED
@@ -0,0 +1,229 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model, build_vision_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ try:
16
+ from torchvision.transforms import InterpolationMode
17
+ BICUBIC = InterpolationMode.BICUBIC
18
+ except ImportError:
19
+ BICUBIC = Image.BICUBIC
20
+
21
+
22
+ if torch.__version__.split(".") < ["1", "7", "1"]:
23
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24
+
25
+
26
+ __all__ = ["available_models", "load", "tokenize"]
27
+ _tokenizer = _Tokenizer()
28
+
29
+ _MODELS = {
30
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36
+ }
37
+
38
+
39
+ def _download(url: str, root: str):
40
+ os.makedirs(root, exist_ok=True)
41
+ filename = os.path.basename(url)
42
+
43
+ expected_sha256 = url.split("/")[-2]
44
+ download_target = os.path.join(root, filename)
45
+
46
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
47
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
48
+
49
+ if os.path.isfile(download_target):
50
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
51
+ return download_target
52
+ else:
53
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
54
+
55
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
56
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
57
+ while True:
58
+ buffer = source.read(8192)
59
+ if not buffer:
60
+ break
61
+
62
+ output.write(buffer)
63
+ loop.update(len(buffer))
64
+
65
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
66
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
67
+
68
+ return download_target
69
+
70
+
71
+ def _convert_image_to_rgb(image):
72
+ return image.convert("RGB")
73
+
74
+
75
+ def _transform(n_px):
76
+ return Compose([
77
+ Resize(n_px, interpolation=BICUBIC),
78
+ CenterCrop(n_px),
79
+ _convert_image_to_rgb,
80
+ ToTensor(),
81
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
82
+ ])
83
+
84
+
85
+ def available_models() -> List[str]:
86
+ """Returns the names of available CLIP models"""
87
+ return list(_MODELS.keys())
88
+
89
+
90
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
91
+ """Load a CLIP model
92
+
93
+ Parameters
94
+ ----------
95
+ name : str
96
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
97
+
98
+ device : Union[str, torch.device]
99
+ The device to put the loaded model
100
+
101
+ jit : bool
102
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
103
+
104
+ download_root: str
105
+ path to download the model files; by default, it uses "~/.cache/clip"
106
+
107
+ Returns
108
+ -------
109
+ model : torch.nn.Module
110
+ The CLIP model
111
+
112
+ preprocess : Callable[[PIL.Image], torch.Tensor]
113
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
114
+ """
115
+ if name in _MODELS:
116
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
117
+ elif os.path.isfile(name):
118
+ model_path = name
119
+ else:
120
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
121
+
122
+ try:
123
+ # loading JIT archive
124
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
125
+ state_dict = None
126
+ except RuntimeError:
127
+ # loading saved state dict
128
+ if jit:
129
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
130
+ jit = False
131
+ state_dict = torch.load(model_path, map_location="cpu")
132
+
133
+ if not jit:
134
+ #model = build_model(state_dict or model.state_dict()).to(device)
135
+ model = build_vision_model(state_dict or model.state_dict()).to(device)
136
+ if str(device) == "cpu":
137
+ model.float()
138
+ return model, _transform(model.visual.input_resolution)
139
+
140
+ # patch the device names
141
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
142
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
143
+
144
+ def patch_device(module):
145
+ try:
146
+ graphs = [module.graph] if hasattr(module, "graph") else []
147
+ except RuntimeError:
148
+ graphs = []
149
+
150
+ if hasattr(module, "forward1"):
151
+ graphs.append(module.forward1.graph)
152
+
153
+ for graph in graphs:
154
+ for node in graph.findAllNodes("prim::Constant"):
155
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
156
+ node.copyAttributes(device_node)
157
+
158
+ model.apply(patch_device)
159
+ patch_device(model.encode_image)
160
+ patch_device(model.encode_text)
161
+
162
+ # patch dtype to float32 on CPU
163
+ if str(device) == "cpu":
164
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
165
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
166
+ float_node = float_input.node()
167
+
168
+ def patch_float(module):
169
+ try:
170
+ graphs = [module.graph] if hasattr(module, "graph") else []
171
+ except RuntimeError:
172
+ graphs = []
173
+
174
+ if hasattr(module, "forward1"):
175
+ graphs.append(module.forward1.graph)
176
+
177
+ for graph in graphs:
178
+ for node in graph.findAllNodes("aten::to"):
179
+ inputs = list(node.inputs())
180
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
181
+ if inputs[i].node()["value"] == 5:
182
+ inputs[i].node().copyAttributes(float_node)
183
+
184
+ model.apply(patch_float)
185
+ patch_float(model.encode_image)
186
+ patch_float(model.encode_text)
187
+
188
+ model.float()
189
+
190
+ return model, _transform(model.input_resolution.item())
191
+
192
+
193
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
194
+ """
195
+ Returns the tokenized representation of given input string(s)
196
+
197
+ Parameters
198
+ ----------
199
+ texts : Union[str, List[str]]
200
+ An input string or a list of input strings to tokenize
201
+
202
+ context_length : int
203
+ The context length to use; all CLIP models use 77 as the context length
204
+
205
+ truncate: bool
206
+ Whether to truncate the text in case its encoding is longer than the context length
207
+
208
+ Returns
209
+ -------
210
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
211
+ """
212
+ if isinstance(texts, str):
213
+ texts = [texts]
214
+
215
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
216
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
217
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
218
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
219
+
220
+ for i, tokens in enumerate(all_tokens):
221
+ if len(tokens) > context_length:
222
+ if truncate:
223
+ tokens = tokens[:context_length]
224
+ tokens[-1] = eot_token
225
+ else:
226
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
227
+ result[i, :len(tokens)] = torch.tensor(tokens)
228
+
229
+ return result
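
For orientation, models/__init__.py uses this file only through clip.load('ViT-B/16', 'cpu'), where the non-JIT path returns the vision tower built by build_vision_model plus the matching preprocessing. A minimal usage sketch follows; the image path is a placeholder, and it assumes the returned object is callable on a preprocessed image batch, which is how get_backbone hands it to ProtoNet:

    import torch
    from PIL import Image
    from models import clip

    model, preprocess = clip.load('ViT-B/16', 'cpu')   # vision tower only on the non-JIT path
    model.eval()

    img = Image.open('example.jpg')                    # placeholder path
    x = preprocess(img).unsqueeze(0)                   # (1, 3, 224, 224)
    with torch.no_grad():
        feat = model(x)                                # image embedding used as the backbone feature
    print(feat.shape)
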
models/clip/model.py ADDED
@@ -0,0 +1,577 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import math
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
+ class Bottleneck(nn.Module):
12
+ expansion = 4
13
+
14
+ def __init__(self, inplanes, planes, stride=1):
15
+ super().__init__()
16
+
17
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
18
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
19
+ self.bn1 = nn.BatchNorm2d(planes)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+
24
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25
+
26
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
27
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
28
+
29
+ self.relu = nn.ReLU(inplace=True)
30
+ self.downsample = None
31
+ self.stride = stride
32
+
33
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
34
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
35
+ self.downsample = nn.Sequential(OrderedDict([
36
+ ("-1", nn.AvgPool2d(stride)),
37
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
38
+ ("1", nn.BatchNorm2d(planes * self.expansion))
39
+ ]))
40
+
41
+ def forward(self, x: torch.Tensor):
42
+ identity = x
43
+
44
+ out = self.relu(self.bn1(self.conv1(x)))
45
+ out = self.relu(self.bn2(self.conv2(out)))
46
+ out = self.avgpool(out)
47
+ out = self.bn3(self.conv3(out))
48
+
49
+ if self.downsample is not None:
50
+ identity = self.downsample(x)
51
+
52
+ out += identity
53
+ out = self.relu(out)
54
+ return out
55
+
56
+
57
+ class AttentionPool2d(nn.Module):
58
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
59
+ super().__init__()
60
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
61
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
62
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
65
+ self.num_heads = num_heads
66
+
67
+ def interpolate_pos_encoding(self, x, h0, w0):
68
+ assert w0 == h0, f'{self} only support square images!'
69
+ pos_embed = self.positional_embedding.unsqueeze(1).to(x.dtype)
70
+ npatch = x.shape[0] - 1
71
+ N = pos_embed.shape[0] - 1
72
+ if npatch == N:
73
+ return pos_embed
74
+ class_pos_embed = pos_embed[0]
75
+ patch_pos_embed = pos_embed[1:]
76
+ dim = x.shape[-1]
77
+ # we add a small number to avoid floating point error in the interpolation
78
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
79
+ w0, h0 = w0 + 0.1, h0 + 0.1
80
+ patch_pos_embed = nn.functional.interpolate(
81
+ patch_pos_embed.reshape(int(math.sqrt(N)), int(math.sqrt(N)), 1, dim).permute(2, 3, 0, 1),
82
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
83
+ mode='bicubic',
84
+ align_corners=False,
85
+ recompute_scale_factor=False
86
+ )
87
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
88
+ patch_pos_embed = patch_pos_embed.permute(2, 3, 0, 1).view(-1, 1, dim)
89
+ return torch.cat((class_pos_embed.unsqueeze(1), patch_pos_embed), dim=0)
90
+
91
+ def forward(self, x):
92
+ B, C, H, W = x.shape
93
+
94
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
95
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
96
+ x = x + self.interpolate_pos_encoding(x, H, W) # (HW+1)NC
97
+ x, _ = F.multi_head_attention_forward(
98
+ query=x, key=x, value=x,
99
+ embed_dim_to_check=x.shape[-1],
100
+ num_heads=self.num_heads,
101
+ q_proj_weight=self.q_proj.weight,
102
+ k_proj_weight=self.k_proj.weight,
103
+ v_proj_weight=self.v_proj.weight,
104
+ in_proj_weight=None,
105
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
106
+ bias_k=None,
107
+ bias_v=None,
108
+ add_zero_attn=False,
109
+ dropout_p=0,
110
+ out_proj_weight=self.c_proj.weight,
111
+ out_proj_bias=self.c_proj.bias,
112
+ use_separate_proj_weight=True,
113
+ training=self.training,
114
+ need_weights=False
115
+ )
116
+
117
+ return x[0]
118
+
119
+
120
+ class ModifiedResNet(nn.Module):
121
+ """
122
+ A ResNet class that is similar to torchvision's but contains the following changes:
123
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
124
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
125
+ - The final pooling layer is a QKV attention instead of an average pool
126
+ """
127
+
128
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
129
+ super().__init__()
130
+ self.output_dim = output_dim
131
+ self.input_resolution = input_resolution
132
+
133
+ # the 3-layer stem
134
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
135
+ self.bn1 = nn.BatchNorm2d(width // 2)
136
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
137
+ self.bn2 = nn.BatchNorm2d(width // 2)
138
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
139
+ self.bn3 = nn.BatchNorm2d(width)
140
+ self.avgpool = nn.AvgPool2d(2)
141
+ self.relu = nn.ReLU(inplace=True)
142
+
143
+ # residual layers
144
+ self._inplanes = width # this is a *mutable* variable used during construction
145
+ self.layer1 = self._make_layer(width, layers[0])
146
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
147
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
148
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
149
+
150
+ embed_dim = width * 32 # the ResNet feature dimension
151
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
152
+ #self.gap = nn.AdaptiveAvgPool2d((1, 1))
153
+
154
+ def _make_layer(self, planes, blocks, stride=1):
155
+ layers = [Bottleneck(self._inplanes, planes, stride)]
156
+
157
+ self._inplanes = planes * Bottleneck.expansion
158
+ for _ in range(1, blocks):
159
+ layers.append(Bottleneck(self._inplanes, planes))
160
+
161
+ return nn.Sequential(*layers)
162
+
163
+ def forward(self, x):
164
+ def stem(x):
165
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
166
+ x = self.relu(bn(conv(x)))
167
+ x = self.avgpool(x)
168
+ return x
169
+
170
+ x = x.type(self.conv1.weight.dtype)
171
+ x = stem(x)
172
+ x = self.layer1(x)
173
+ x = self.layer2(x)
174
+ x = self.layer3(x)
175
+ x = self.layer4(x)
176
+ x = self.attnpool(x)
177
+ #x = self.gap(x)
178
+
179
+ return x
180
+
181
+
182
+ class LayerNorm(nn.LayerNorm):
183
+ """Subclass torch's LayerNorm to handle fp16."""
184
+
185
+ def forward(self, x: torch.Tensor):
186
+ orig_type = x.dtype
187
+ ret = super().forward(x.type(torch.float32))
188
+ return ret.type(orig_type)
189
+
190
+
191
+ class QuickGELU(nn.Module):
192
+ def forward(self, x: torch.Tensor):
193
+ return x * torch.sigmoid(1.702 * x)
194
+
195
+
196
+ class ResidualAttentionBlock(nn.Module):
197
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
198
+ super().__init__()
199
+
200
+ self.attn = nn.MultiheadAttention(d_model, n_head)
201
+ self.ln_1 = LayerNorm(d_model)
202
+ self.mlp = nn.Sequential(OrderedDict([
203
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
204
+ ("gelu", QuickGELU()),
205
+ ("c_proj", nn.Linear(d_model * 4, d_model))
206
+ ]))
207
+ self.ln_2 = LayerNorm(d_model)
208
+ self.attn_mask = attn_mask
209
+
210
+ def attention(self, x: torch.Tensor):
211
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
212
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
213
+
214
+ def forward(self, x: torch.Tensor):
215
+ x = x + self.attention(self.ln_1(x))
216
+ x = x + self.mlp(self.ln_2(x))
217
+ return x
218
+
219
+
220
+ class Transformer(nn.Module):
221
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
222
+ super().__init__()
223
+ self.width = width
224
+ self.layers = layers
225
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
226
+
227
+ def forward(self, x: torch.Tensor):
228
+ return self.resblocks(x)
229
+
230
+
231
+ class VisionTransformer(nn.Module):
232
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
233
+ super().__init__()
234
+ self.input_resolution = input_resolution
235
+ self.output_dim = output_dim
236
+ self.patch_size = patch_size
237
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
238
+
239
+ scale = width ** -0.5
240
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
241
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
242
+ self.ln_pre = LayerNorm(width)
243
+
244
+ self.transformer = Transformer(width, layers, heads)
245
+
246
+ self.ln_post = LayerNorm(width)
247
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
248
+
249
+ def interpolate_pos_encoding(self, x, h, w):
250
+ pos_embed = self.positional_embedding.unsqueeze(0).to(x.dtype)
251
+ npatch = x.shape[1] - 1
252
+ N = pos_embed.shape[1] - 1
253
+ if npatch == N and w == h:
254
+ return pos_embed
255
+ class_pos_embed = pos_embed[:, 0]
256
+ patch_pos_embed = pos_embed[:, 1:]
257
+ dim = x.shape[-1]
258
+ w0 = w // self.patch_size
259
+ h0 = h // self.patch_size
260
+ # we add a small number to avoid floating point error in the interpolation
261
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
262
+ w0, h0 = w0 + 0.1, h0 + 0.1
263
+ patch_pos_embed = nn.functional.interpolate(
264
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
265
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
266
+ mode='bicubic',
267
+ align_corners=False,
268
+ recompute_scale_factor=False
269
+ )
270
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
271
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
272
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
273
+
274
+ def forward(self, x: torch.Tensor):
275
+ B, C, H, W = x.shape
276
+
277
+ x = self.conv1(x) # shape = [*, width, grid, grid]
278
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
279
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
280
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
281
+ x = x + self.interpolate_pos_encoding(x, H, W)
282
+ x = self.ln_pre(x)
283
+
284
+ x = x.permute(1, 0, 2) # NLD -> LND
285
+ x = self.transformer(x)
286
+ x = x.permute(1, 0, 2) # LND -> NLD
287
+
288
+ x = self.ln_post(x[:, 0, :])
289
+
290
+ if self.proj is not None:
291
+ x = x @ self.proj
292
+
293
+ return x
294
+
295
+
296
+ class VisionBackbone(nn.Module):
297
+ def __init__(self,
298
+ embed_dim: int,
299
+ # vision
300
+ image_resolution: int,
301
+ vision_layers: Union[Tuple[int, int, int, int], int],
302
+ vision_width: int,
303
+ vision_patch_size: int,
304
+ ):
305
+ super().__init__()
306
+
307
+ if isinstance(vision_layers, (tuple, list)):
308
+ vision_heads = vision_width * 32 // 64
309
+ self.visual = ModifiedResNet(
310
+ layers=vision_layers,
311
+ output_dim=embed_dim,
312
+ heads=vision_heads,
313
+ input_resolution=image_resolution,
314
+ width=vision_width
315
+ )
316
+ else:
317
+ vision_heads = vision_width // 64
318
+ self.visual = VisionTransformer(
319
+ input_resolution=image_resolution,
320
+ patch_size=vision_patch_size,
321
+ width=vision_width,
322
+ layers=vision_layers,
323
+ heads=vision_heads,
324
+ output_dim=embed_dim
325
+ )
326
+
327
+ #self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
328
+
329
+ self.initialize_parameters()
330
+
331
+ def initialize_parameters(self):
332
+ if isinstance(self.visual, ModifiedResNet):
333
+ if self.visual.attnpool is not None:
334
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
335
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
336
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
337
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
338
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
339
+
340
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
341
+ for name, param in resnet_block.named_parameters():
342
+ if name.endswith("bn3.weight"):
343
+ nn.init.zeros_(param)
344
+
345
+ @property
346
+ def dtype(self):
347
+ return self.visual.conv1.weight.dtype
348
+
349
+ def forward(self, image):
350
+ return self.visual(image.type(self.dtype))
351
+
352
+
353
+ class CLIP(nn.Module):
354
+ def __init__(self,
355
+ embed_dim: int,
356
+ # vision
357
+ image_resolution: int,
358
+ vision_layers: Union[Tuple[int, int, int, int], int],
359
+ vision_width: int,
360
+ vision_patch_size: int,
361
+ # text
362
+ context_length: int,
363
+ vocab_size: int,
364
+ transformer_width: int,
365
+ transformer_heads: int,
366
+ transformer_layers: int
367
+ ):
368
+ super().__init__()
369
+
370
+ self.context_length = context_length
371
+
372
+ if isinstance(vision_layers, (tuple, list)):
373
+ vision_heads = vision_width * 32 // 64
374
+ self.visual = ModifiedResNet(
375
+ layers=vision_layers,
376
+ output_dim=embed_dim,
377
+ heads=vision_heads,
378
+ input_resolution=image_resolution,
379
+ width=vision_width
380
+ )
381
+ else:
382
+ vision_heads = vision_width // 64
383
+ self.visual = VisionTransformer(
384
+ input_resolution=image_resolution,
385
+ patch_size=vision_patch_size,
386
+ width=vision_width,
387
+ layers=vision_layers,
388
+ heads=vision_heads,
389
+ output_dim=embed_dim
390
+ )
391
+
392
+ self.transformer = Transformer(
393
+ width=transformer_width,
394
+ layers=transformer_layers,
395
+ heads=transformer_heads,
396
+ attn_mask=self.build_attention_mask()
397
+ )
398
+
399
+ self.vocab_size = vocab_size
400
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
401
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
402
+ self.ln_final = LayerNorm(transformer_width)
403
+
404
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
405
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
406
+
407
+ self.initialize_parameters()
408
+
409
+ def initialize_parameters(self):
410
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
411
+ nn.init.normal_(self.positional_embedding, std=0.01)
412
+
413
+ if isinstance(self.visual, ModifiedResNet):
414
+ if self.visual.attnpool is not None:
415
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
416
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
417
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
418
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
419
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
420
+
421
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
422
+ for name, param in resnet_block.named_parameters():
423
+ if name.endswith("bn3.weight"):
424
+ nn.init.zeros_(param)
425
+
426
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
427
+ attn_std = self.transformer.width ** -0.5
428
+ fc_std = (2 * self.transformer.width) ** -0.5
429
+ for block in self.transformer.resblocks:
430
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
431
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
432
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
433
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
434
+
435
+ if self.text_projection is not None:
436
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
437
+
438
+ def build_attention_mask(self):
439
+ # lazily create causal attention mask, with full attention between the vision tokens
440
+ # pytorch uses additive attention mask; fill with -inf
441
+ mask = torch.empty(self.context_length, self.context_length)
442
+ mask.fill_(float("-inf"))
443
+ mask.triu_(1) # keep -inf strictly above the diagonal; zero the diagonal and below (causal mask)
444
+ return mask
445
+
446
+ @property
447
+ def dtype(self):
448
+ return self.visual.conv1.weight.dtype
449
+
450
+ def encode_image(self, image):
451
+ return self.visual(image.type(self.dtype))
452
+
453
+ def encode_text(self, text):
454
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
455
+
456
+ x = x + self.positional_embedding.type(self.dtype)
457
+ x = x.permute(1, 0, 2) # NLD -> LND
458
+ x = self.transformer(x)
459
+ x = x.permute(1, 0, 2) # LND -> NLD
460
+ x = self.ln_final(x).type(self.dtype)
461
+
462
+ # x.shape = [batch_size, n_ctx, transformer.width]
463
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
464
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
465
+
466
+ return x
467
+
468
+ def forward(self, image, text):
469
+ image_features = self.encode_image(image)
470
+ text_features = self.encode_text(text)
471
+
472
+ # normalized features
473
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
474
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
475
+
476
+ # cosine similarity as logits
477
+ logit_scale = self.logit_scale.exp()
478
+ logits_per_image = logit_scale * image_features @ text_features.t()
479
+ logits_per_text = logits_per_image.t()
480
+
481
+ # shape = [global_batch_size, global_batch_size]
482
+ return logits_per_image, logits_per_text
483
+
484
+
485
+ def convert_weights(model: nn.Module):
486
+ """Convert applicable model parameters to fp16"""
487
+
488
+ def _convert_weights_to_fp16(l):
489
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
490
+ l.weight.data = l.weight.data.half()
491
+ if l.bias is not None:
492
+ l.bias.data = l.bias.data.half()
493
+
494
+ if isinstance(l, nn.MultiheadAttention):
495
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
496
+ tensor = getattr(l, attr)
497
+ if tensor is not None:
498
+ tensor.data = tensor.data.half()
499
+
500
+ for name in ["text_projection", "proj"]:
501
+ if hasattr(l, name):
502
+ attr = getattr(l, name)
503
+ if attr is not None:
504
+ attr.data = attr.data.half()
505
+
506
+ model.apply(_convert_weights_to_fp16)
507
+
508
+
509
+ def build_model(state_dict: dict):
510
+ vit = "visual.proj" in state_dict
511
+
512
+ if vit:
513
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
514
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
515
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
516
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
517
+ image_resolution = vision_patch_size * grid_size
518
+ else:
519
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
520
+ vision_layers = tuple(counts)
521
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
522
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
523
+ vision_patch_size = None
524
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
525
+ image_resolution = output_width * 32
526
+
527
+ embed_dim = state_dict["text_projection"].shape[1]
528
+ context_length = state_dict["positional_embedding"].shape[0]
529
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
530
+ transformer_width = state_dict["ln_final.weight"].shape[0]
531
+ transformer_heads = transformer_width // 64
532
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
533
+
534
+ model = CLIP(
535
+ embed_dim,
536
+ image_resolution, vision_layers, vision_width, vision_patch_size,
537
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
538
+ )
539
+
540
+ for key in ["input_resolution", "context_length", "vocab_size"]:
541
+ if key in state_dict:
542
+ del state_dict[key]
543
+
544
+ convert_weights(model)
545
+ model.load_state_dict(state_dict)
546
+ return model.eval()
547
+
548
+
549
+ def build_vision_model(state_dict: dict):
550
+ vit = "visual.proj" in state_dict
551
+
552
+ if vit:
553
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
554
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
555
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
556
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
557
+ image_resolution = vision_patch_size * grid_size
558
+ else:
559
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
560
+ vision_layers = tuple(counts)
561
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
562
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
563
+ vision_patch_size = None
564
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
565
+ image_resolution = output_width * 32
566
+
567
+ embed_dim = state_dict["text_projection"].shape[1]
568
+
569
+ model = VisionBackbone(
570
+ embed_dim,
571
+ image_resolution, vision_layers, vision_width, vision_patch_size,
572
+ )
573
+
574
+ convert_weights(model)
575
+ msg = model.load_state_dict(state_dict, strict=False)
576
+ print(f'clip.build_vision_model: pretrained weights loaded with message: {msg}')
577
+ return model.eval()
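
A minimal usage sketch for the vision-only builder above (not part of the commit): the checkpoint file name is hypothetical, the state dict is assumed to already be a plain dict with the original OpenAI CLIP key names, and the import path is a guess at where this file lives under models/clip/.

import torch
from models.clip.clip import build_vision_model  # module path assumed for this repo layout

# plain state_dict with OpenAI CLIP key names (file name is hypothetical)
state_dict = torch.load("clip_rn50_state_dict.pt", map_location="cpu")
vision = build_vision_model(state_dict).float()  # convert_weights() cast it to fp16; fp32 is safer on CPU
with torch.no_grad():
    feats = vision(torch.randn(1, 3, 224, 224))  # (1, embed_dim) image features
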
models/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:  # no further occurrence of `first` in the word
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
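
A quick sketch of how the tokenizer above is typically driven (not part of the commit); it assumes ftfy and regex are installed and that bpe_simple_vocab_16e6.txt.gz sits next to the module, as default_bpe() expects.

from models.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()                          # uses the bundled BPE vocab by default
ids = tok.encode("a photo of a labrador puppy")  # list of BPE token ids
print(tok.decode(ids))                           # round-trips up to lowercasing/whitespace cleanup
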
models/deploy.py ADDED
@@ -0,0 +1,389 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from copy import deepcopy
7
+ from tqdm import tqdm
8
+ from timm.utils import accuracy
9
+ from .protonet import ProtoNet
10
+ from .utils import trunc_normal_, DiffAugment
11
+
12
+
13
+ def is_dist_avail_and_initialized():
14
+ if not dist.is_available():
15
+ return False
16
+ if not dist.is_initialized():
17
+ return False
18
+ return True
19
+
20
+
21
+ def get_rank():
22
+ if not is_dist_avail_and_initialized():
23
+ return 0
24
+ return dist.get_rank()
25
+
26
+
27
+ def is_main_process():
28
+ return get_rank() == 0
29
+
30
+
31
+ @torch.jit.script
32
+ def entropy_loss(x):
33
+ return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()
34
+
35
+
36
+ def unique_indices(x):
37
+ """
38
+ Ref: https://github.com/rusty1s/pytorch_unique
39
+ """
40
+ unique, inverse = torch.unique(x, sorted=True, return_inverse=True)
41
+ perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
42
+ inverse, perm = inverse.flip([0]), perm.flip([0])
43
+ perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
44
+ return unique, perm
45
+
46
+
47
+ class ProtoNet_Auto_Finetune(ProtoNet):
48
+ def __init__(self, backbone, num_iters=50, aug_prob=0.9,
49
+ aug_types=['color', 'translation'], lr_lst=[0.01, 0.001, 0.0001]):
50
+ super().__init__(backbone)
51
+ self.num_iters = num_iters
52
+ self.lr_lst = lr_lst
53
+ self.aug_types = aug_types
54
+ self.aug_prob = aug_prob
55
+
56
+ state_dict = backbone.state_dict()
57
+ self.backbone_state = deepcopy(state_dict)
58
+
59
+ def forward(self, supp_x, supp_y, qry_x):
60
+ """
61
+ supp_x.shape = [B, nSupp, C, H, W]
62
+ supp_y.shape = [B, nSupp]
63
+ qry_x.shape = [B, nQry, C, H, W]
64
+ """
65
+ B, nSupp, C, H, W = supp_x.shape
66
+ num_classes = supp_y.max() + 1 # NOTE: assume B==1
67
+ device = qry_x.device
68
+
69
+ criterion = nn.CrossEntropyLoss()
70
+ supp_x = supp_x.view(-1, C, H, W)
71
+ qry_x = qry_x.view(-1, C, H, W)
72
+ supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
73
+ supp_y = supp_y.view(-1)
74
+
75
+ def single_step(z, mode=True, x=None, y=None, y_1hot=None):
76
+ '''
77
+ z = Aug(supp_x) or qry_x
78
+ global vars: supp_x, supp_y, supp_y_1hot
79
+ '''
80
+ with torch.set_grad_enabled(mode):
81
+ # recalculate prototypes from supp_x with updated backbone
82
+ proto_f = self.backbone.forward(x).unsqueeze(0)
83
+
84
+ if y_1hot is None:
85
+ prototypes = proto_f
86
+ else:
87
+ prototypes = torch.bmm(y_1hot.float(), proto_f) # B, nC, d
88
+ prototypes = prototypes / y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
89
+
90
+ # compute feature for z
91
+ feat = self.backbone.forward(z)
92
+ feat = feat.view(B, z.shape[0], -1) # B, nQry, d
93
+
94
+ # classification
95
+ logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
96
+ loss = None
97
+
98
+ if mode: # if enable grad, compute loss
99
+ loss = criterion(logits.view(len(y), -1), y)
100
+
101
+ return logits, loss
102
+
103
+ # load trained weights
104
+ self.backbone.load_state_dict(self.backbone_state, strict=True)
105
+
106
+ #zz = DiffAugment(supp_x, ["color", "offset", "offset_h", "offset_v", "translation", "cutout"], 1., detach=True)
107
+ proto_y, proto_i = unique_indices(supp_y)
108
+ proto_x = supp_x[proto_i]
109
+ zz_i = np.setdiff1d(range(len(supp_x)), proto_i.cpu().numpy())
110
+ zz_x = supp_x[zz_i]
111
+ zz_y = supp_y[zz_i]
112
+
113
+ best_lr = 0
114
+ max_acc1 = 0
115
+
116
+ if len(zz_y) > 0:
117
+ # eval non-finetuned weights (lr=0)
118
+ logits, _ = single_step(zz_x, False, x=proto_x)
119
+ max_acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
120
+ print(f'## *lr = 0: acc1 = {max_acc1}\n')
121
+
122
+ for lr in self.lr_lst:
123
+ # create optimizer
124
+ opt = torch.optim.Adam(self.backbone.parameters(),
125
+ lr=lr,
126
+ betas=(0.9, 0.999),
127
+ weight_decay=0.)
128
+
129
+ # main loop
130
+ _num_iters = 50
131
+ pbar = tqdm(range(_num_iters)) if is_main_process() else range(_num_iters)
132
+ for i in pbar:
133
+ opt.zero_grad()
134
+ z = DiffAugment(proto_x, self.aug_types, self.aug_prob, detach=True)
135
+ _, loss = single_step(z, True, x=proto_x, y=proto_y)
136
+ loss.backward()
137
+ opt.step()
138
+ if is_main_process():
139
+ pbar.set_description(f' << lr = {lr}: loss = {loss.item()}')
140
+
141
+ logits, _ = single_step(zz_x, False, x=proto_x)
142
+ acc1 = accuracy(logits.view(len(zz_y), -1), zz_y, topk=(1,))[0]
143
+ print(f'## *lr = {lr}: acc1 = {acc1}\n')
144
+
145
+ if acc1 > max_acc1:
146
+ max_acc1 = acc1
147
+ best_lr = lr
148
+
149
+ # reset backbone state
150
+ self.backbone.load_state_dict(self.backbone_state, strict=True)
151
+
152
+ print(f'***Best lr = {best_lr} with acc1 = {max_acc1}.\nStart final loop...\n')
153
+
154
+ # create optimizer
155
+ opt = torch.optim.Adam(self.backbone.parameters(),
156
+ lr=best_lr,
157
+ betas=(0.9, 0.999),
158
+ weight_decay=0.)
159
+
160
+ # main loop
161
+ pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
162
+ for i in pbar:
163
+ opt.zero_grad()
164
+ z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
165
+ _, loss = single_step(z, True, x=supp_x, y=supp_y, y_1hot=supp_y_1hot)
166
+ loss.backward()
167
+ opt.step()
168
+ if is_main_process():
169
+ pbar.set_description(f' >> lr = {best_lr}: loss = {loss.item()}')
170
+
171
+ logits, _ = single_step(qry_x, False, x=supp_x, y_1hot=supp_y_1hot) # supp_x has to pair with y_1hot
172
+
173
+ return logits
174
+
175
+
176
+ class ProtoNet_Finetune(ProtoNet):
177
+ def __init__(self, backbone, num_iters=50, lr=5e-2, aug_prob=0.9,
178
+ aug_types=['color', 'translation']):
179
+ super().__init__(backbone)
180
+ self.num_iters = num_iters
181
+ self.lr = lr
182
+ self.aug_types = aug_types
183
+ self.aug_prob = aug_prob
184
+
185
+ def load_state_dict(self, state_dict, strict=True):
186
+ super().load_state_dict(state_dict, strict)
187
+
188
+ state_dict = self.backbone.state_dict()
189
+ self.backbone_state = deepcopy(state_dict)
190
+
191
+ def forward(self, supp_x, supp_y, x):
192
+ """
193
+ supp_x.shape = [B, nSupp, C, H, W]
194
+ supp_y.shape = [B, nSupp]
195
+ x.shape = [B, nQry, C, H, W]
196
+ """
197
+ # reset backbone state
198
+ self.backbone.load_state_dict(self.backbone_state, strict=True)
199
+
200
+ if self.lr == 0:
201
+ return super().forward(supp_x, supp_y, x)
202
+
203
+ B, nSupp, C, H, W = supp_x.shape
204
+ num_classes = supp_y.max() + 1 # NOTE: assume B==1
205
+ device = x.device
206
+
207
+ criterion = nn.CrossEntropyLoss()
208
+ supp_x = supp_x.view(-1, C, H, W)
209
+ x = x.view(-1, C, H, W)
210
+ supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
211
+ supp_y = supp_y.view(-1)
212
+
213
+ # create optimizer
214
+ opt = torch.optim.Adam(self.backbone.parameters(),
215
+ lr=self.lr,
216
+ betas=(0.9, 0.999),
217
+ weight_decay=0.)
218
+
219
+ def single_step(z, mode=True):
220
+ '''
221
+ z = Aug(supp_x) or x
222
+ '''
223
+ with torch.set_grad_enabled(mode):
224
+ # recalculate prototypes from supp_x with updated backbone
225
+ supp_f = self.backbone.forward(supp_x)
226
+ supp_f = supp_f.view(B, nSupp, -1)
227
+ prototypes = torch.bmm(supp_y_1hot.float(), supp_f) # B, nC, d
228
+ prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
229
+
230
+ # compute feature for z
231
+ feat = self.backbone.forward(z)
232
+ feat = feat.view(B, z.shape[0], -1) # B, nQry, d
233
+
234
+ # classification
235
+ logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
236
+ loss = None
237
+
238
+ if mode: # if enable grad, compute loss
239
+ loss = criterion(logits.view(B*nSupp, -1), supp_y)
240
+
241
+ return logits, loss
242
+
243
+ # main loop
244
+ pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
245
+ for i in pbar:
246
+ opt.zero_grad()
247
+ z = DiffAugment(supp_x, self.aug_types, self.aug_prob, detach=True)
248
+ _, loss = single_step(z, True)
249
+ loss.backward()
250
+ opt.step()
251
+ if is_main_process():
252
+ pbar.set_description(f'lr{self.lr}, nSupp{nSupp}, nQry{x.shape[0]}: loss = {loss.item()}')
253
+
254
+ logits, _ = single_step(x, False)
255
+ return logits
256
+
257
+
258
+ class ProtoNet_AdaTok(ProtoNet):
259
+ def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-2, momentum=0.9, weight_decay=0.):
260
+ super().__init__(backbone)
261
+ self.num_adapters = num_adapters
262
+ self.num_iters = num_iters
263
+ self.lr = lr
264
+ self.momentum = momentum
265
+ self.weight_decay = weight_decay
266
+
267
+ def forward(self, supp_x, supp_y, x):
268
+ """
269
+ supp_x.shape = [B, nSupp, C, H, W]
270
+ supp_y.shape = [B, nSupp]
271
+ x.shape = [B, nQry, C, H, W]
272
+ """
273
+ B, nSupp, C, H, W = supp_x.shape
274
+ nQry = x.shape[1]
275
+ num_classes = supp_y.max() + 1 # NOTE: assume B==1
276
+ device = x.device
277
+
278
+ criterion = nn.CrossEntropyLoss()
279
+ supp_x = supp_x.view(-1, C, H, W)
280
+ x = x.view(-1, C, H, W)
281
+ supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
282
+ supp_y = supp_y.view(-1)
283
+
284
+ # prepare adapter tokens
285
+ ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
286
+ trunc_normal_(ada_tokens, std=.02)
287
+ ada_tokens = ada_tokens.detach().requires_grad_()
288
+ #optimizer = torch.optim.SGD([ada_tokens],
289
+ optimizer = torch.optim.Adadelta([ada_tokens],
290
+ lr=self.lr,
291
+ #momentum=self.momentum,
292
+ weight_decay=self.weight_decay)
293
+
294
+ def single_step(mode=True):
295
+ with torch.set_grad_enabled(mode):
296
+ supp_f = self.backbone.forward(supp_x, ada_tokens)
297
+ supp_f = supp_f.view(B, nSupp, -1)
298
+
299
+ # B, nC, nSupp x B, nSupp, d = B, nC, d
300
+ prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
301
+ prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
302
+
303
+ if mode == False: # no grad
304
+ feat = self.backbone.forward(x, ada_tokens)
305
+ feat = feat.view(B, nQry, -1) # B, nQry, d
306
+
307
+ logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
308
+ loss = None
309
+ else:
310
+ with torch.enable_grad():
311
+ logits = self.cos_classifier(prototypes, supp_f) # B, nQry, nC
312
+ loss = criterion(logits.view(B*nSupp, -1), supp_y)
313
+
314
+ return logits, loss
315
+
316
+ pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
317
+ for i in pbar:
318
+ optimizer.zero_grad()
319
+ _, loss = single_step(True)
320
+ loss.backward()
321
+ optimizer.step()
322
+ if is_main_process():
323
+ pbar.set_description(f'loss = {loss.item()}')
324
+
325
+ logits, _ = single_step(False)
326
+ return logits
327
+
328
+
329
+ class ProtoNet_AdaTok_EntMin(ProtoNet):
330
+ def __init__(self, backbone, num_adapters=1, num_iters=50, lr=5e-3, momentum=0.9, weight_decay=0.):
331
+ super().__init__(backbone)
332
+ self.num_adapters = num_adapters
333
+ self.num_iters = num_iters
334
+ self.lr = lr
335
+ self.momentum = momentum
336
+ self.weight_decay = weight_decay
337
+
338
+ def forward(self, supp_x, supp_y, x):
339
+ """
340
+ supp_x.shape = [B, nSupp, C, H, W]
341
+ supp_y.shape = [B, nSupp]
342
+ x.shape = [B, nQry, C, H, W]
343
+ """
344
+ B, nSupp, C, H, W = supp_x.shape
345
+ num_classes = supp_y.max() + 1 # NOTE: assume B==1
346
+ device = x.device
347
+
348
+ criterion = entropy_loss
349
+ supp_x = supp_x.view(-1, C, H, W)
350
+ x = x.view(-1, C, H, W)
351
+ supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
352
+
353
+ # adapter tokens
354
+ ada_tokens = torch.zeros(1, self.num_adapters, self.backbone.embed_dim, device=device)
355
+ trunc_normal_(ada_tokens, std=.02)
356
+ ada_tokens = ada_tokens.detach().requires_grad_()
357
+ optimizer = torch.optim.SGD([ada_tokens],
358
+ lr=self.lr,
359
+ momentum=self.momentum,
360
+ weight_decay=self.weight_decay)
361
+
362
+ def single_step(mode=True):
363
+ with torch.set_grad_enabled(mode):
364
+ supp_f = self.backbone.forward(supp_x, ada_tokens)
365
+ supp_f = supp_f.view(B, nSupp, -1)
366
+
367
+ # B, nC, nSupp x B, nSupp, d = B, nC, d
368
+ prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
369
+ prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0
370
+
371
+ feat = self.backbone.forward(x, ada_tokens)
372
+ feat = feat.view(B, x.shape[1], -1) # B, nQry, d
373
+
374
+ logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
375
+ loss = criterion(logits.view(-1, num_classes))
376
+
377
+ return logits, loss
378
+
379
+ pbar = tqdm(range(self.num_iters)) if is_main_process() else range(self.num_iters)
380
+ for i in pbar:
381
+ optimizer.zero_grad()
382
+ _, loss = single_step(True)
383
+ loss.backward()
384
+ optimizer.step()
385
+ if is_main_process():
386
+ pbar.set_description(f'loss = {loss.item()}')
387
+
388
+ logits, _ = single_step(False)
389
+ return logits
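
To make the episode interface of these deployment wrappers concrete, here is a hedged sketch (not part of the commit) that pairs ProtoNet_Finetune with the DINO-style vit_small from models/vision_transformer.py. Shapes and hyperparameters are illustrative only, and the self-load call is just one way to populate backbone_state, which forward() expects.

import torch
from models.deploy import ProtoNet_Finetune
from models.vision_transformer import vit_small

backbone = vit_small(patch_size=16)                      # randomly initialised here; normally pretrained
net = ProtoNet_Finetune(backbone, num_iters=2, lr=1e-4)  # tiny finetune loop, for illustration only
net.load_state_dict(net.state_dict())                    # snapshots backbone weights into backbone_state

supp_x = torch.randn(1, 4, 3, 224, 224)                  # B=1, nSupp=4
supp_y = torch.tensor([[0, 0, 1, 1]])                    # 2 classes, 2 shots each
qry_x = torch.randn(1, 2, 3, 224, 224)                   # nQry=2
logits = net(supp_x, supp_y, qry_x)                      # (1, 2, 2); slow on CPU, illustration only
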
models/protonet.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ProtoNet(nn.Module):
7
+ def __init__(self, backbone):
8
+ super().__init__()
9
+
10
+ # bias & scale of cosine classifier
11
+ self.bias = nn.Parameter(torch.FloatTensor(1).fill_(0), requires_grad=True)
12
+ self.scale_cls = nn.Parameter(torch.FloatTensor(1).fill_(10), requires_grad=True)
13
+
14
+ # backbone
15
+ self.backbone = backbone
16
+
17
+ def cos_classifier(self, w, f):
18
+ """
19
+ w.shape = B, nC, d
20
+ f.shape = B, M, d
21
+ """
22
+ f = F.normalize(f, p=2, dim=f.dim()-1, eps=1e-12)
23
+ w = F.normalize(w, p=2, dim=w.dim()-1, eps=1e-12)
24
+
25
+ cls_scores = f @ w.transpose(1, 2) # B, M, nC
26
+ cls_scores = self.scale_cls * (cls_scores + self.bias)
27
+ return cls_scores
28
+
29
+ def forward(self, supp_x, supp_y, x):
30
+ """
31
+ supp_x.shape = [B, nSupp, C, H, W]
32
+ supp_y.shape = [B, nSupp]
33
+ x.shape = [B, nQry, C, H, W]
34
+ """
35
+ num_classes = supp_y.max() + 1 # NOTE: assume B==1
36
+
37
+ B, nSupp, C, H, W = supp_x.shape
38
+ supp_f = self.backbone.forward(supp_x.view(-1, C, H, W))
39
+ supp_f = supp_f.view(B, nSupp, -1)
40
+
41
+ supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(1, 2) # B, nC, nSupp
42
+
43
+ # B, nC, nSupp x B, nSupp, d = B, nC, d
44
+ prototypes = torch.bmm(supp_y_1hot.float(), supp_f)
45
+ prototypes = prototypes / supp_y_1hot.sum(dim=2, keepdim=True) # NOTE: may div 0 if some classes got 0 images
46
+
47
+ feat = self.backbone.forward(x.view(-1, C, H, W))
48
+ feat = feat.view(B, x.shape[1], -1) # B, nQry, d
49
+
50
+ logits = self.cos_classifier(prototypes, feat) # B, nQry, nC
51
+ return logits
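
A shape-checking sketch of the ProtoNet episode above (not part of the commit). TinyBackbone is a hypothetical stand-in so the example runs instantly; any module mapping (N, C, H, W) to (N, d) would do.

import torch
import torch.nn as nn
from models.protonet import ProtoNet

class TinyBackbone(nn.Module):
    def forward(self, x):          # (N, C, H, W) -> (N, d) via global average pooling
        return x.mean(dim=(2, 3))

net = ProtoNet(TinyBackbone())
supp_x = torch.randn(1, 6, 3, 32, 32)        # B=1, nSupp=6
supp_y = torch.tensor([[0, 0, 1, 1, 2, 2]])  # 3 classes, 2 shots each (no empty class)
qry_x = torch.randn(1, 5, 3, 32, 32)         # nQry=5
logits = net(supp_x, supp_y, qry_x)          # torch.Size([1, 5, 3])
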
models/resnet_v2.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright 2020 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Lint as: python3
16
+ """Bottleneck ResNet v2 with GroupNorm and Weight Standardization."""
17
+ import math
18
+
19
+ from os.path import join as pjoin
20
+
21
+ from collections import OrderedDict # pylint: disable=g-importing-member
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+
28
+ def np2th(weights, conv=False):
29
+ """Possibly convert HWIO to OIHW."""
30
+ if conv:
31
+ weights = weights.transpose([3, 2, 0, 1])
32
+ return torch.from_numpy(weights)
33
+
34
+
35
+ class StdConv2d(nn.Conv2d):
36
+
37
+ def forward(self, x):
38
+ w = self.weight
39
+ v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
40
+ w = (w - m) / torch.sqrt(v + 1e-5)
41
+ return F.conv2d(x, w, self.bias, self.stride, self.padding,
42
+ self.dilation, self.groups)
43
+
44
+
45
+ def conv3x3(cin, cout, stride=1, groups=1, bias=False):
46
+ return StdConv2d(cin, cout, kernel_size=3, stride=stride,
47
+ padding=1, bias=bias, groups=groups)
48
+
49
+
50
+ def conv1x1(cin, cout, stride=1, bias=False):
51
+ return StdConv2d(cin, cout, kernel_size=1, stride=stride,
52
+ padding=0, bias=bias)
53
+
54
+
55
+ class PreActBottleneck(nn.Module):
56
+ """Pre-activation (v2) bottleneck block.
57
+ """
58
+
59
+ def __init__(self, cin, cout=None, cmid=None, stride=1):
60
+ super().__init__()
61
+ cout = cout or cin
62
+ cmid = cmid or cout//4
63
+
64
+ self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
65
+ self.conv1 = conv1x1(cin, cmid, bias=False)
66
+ self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
67
+ self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
68
+ self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
69
+ self.conv3 = conv1x1(cmid, cout, bias=False)
70
+ self.relu = nn.ReLU(inplace=True)
71
+
72
+ if (stride != 1 or cin != cout):
73
+ # Projection also with pre-activation according to paper.
74
+ self.downsample = conv1x1(cin, cout, stride, bias=False)
75
+ self.gn_proj = nn.GroupNorm(cout, cout)
76
+
77
+ def forward(self, x):
78
+
79
+ # Residual branch
80
+ residual = x
81
+ if hasattr(self, 'downsample'):
82
+ residual = self.downsample(x)
83
+ residual = self.gn_proj(residual)
84
+
85
+ # Unit's branch
86
+ y = self.relu(self.gn1(self.conv1(x)))
87
+ y = self.relu(self.gn2(self.conv2(y)))
88
+ y = self.gn3(self.conv3(y))
89
+
90
+ y = self.relu(residual + y)
91
+ return y
92
+
93
+ def load_from(self, weights, n_block, n_unit):
94
+ conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
95
+ conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
96
+ conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)
97
+
98
+ gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
99
+ gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
100
+
101
+ gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
102
+ gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
103
+
104
+ gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
105
+ gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])
106
+
107
+ self.conv1.weight.copy_(conv1_weight)
108
+ self.conv2.weight.copy_(conv2_weight)
109
+ self.conv3.weight.copy_(conv3_weight)
110
+
111
+ self.gn1.weight.copy_(gn1_weight.view(-1))
112
+ self.gn1.bias.copy_(gn1_bias.view(-1))
113
+
114
+ self.gn2.weight.copy_(gn2_weight.view(-1))
115
+ self.gn2.bias.copy_(gn2_bias.view(-1))
116
+
117
+ self.gn3.weight.copy_(gn3_weight.view(-1))
118
+ self.gn3.bias.copy_(gn3_bias.view(-1))
119
+
120
+ if hasattr(self, 'downsample'):
121
+ proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
122
+ proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
123
+ proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])
124
+
125
+ self.downsample.weight.copy_(proj_conv_weight)
126
+ self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
127
+ self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
128
+
129
+ class ResNetV2(nn.Module):
130
+ """Implementation of Pre-activation (v2) ResNet mode."""
131
+
132
+ def __init__(self, block_units, width_factor):
133
+ super().__init__()
134
+ width = int(64 * width_factor)
135
+ self.width = width
136
+
137
+ # The following will be unreadable if we split lines.
138
+ # pylint: disable=line-too-long
139
+ self.root = nn.Sequential(OrderedDict([
140
+ ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
141
+ ('gn', nn.GroupNorm(32, width, eps=1e-6)),
142
+ ('relu', nn.ReLU(inplace=True)),
143
+ ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
144
+ ]))
145
+
146
+ self.body = nn.Sequential(OrderedDict([
147
+ ('block1', nn.Sequential(OrderedDict(
148
+ [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
149
+ [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
150
+ ))),
151
+ ('block2', nn.Sequential(OrderedDict(
152
+ [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
153
+ [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
154
+ ))),
155
+ ('block3', nn.Sequential(OrderedDict(
156
+ [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
157
+ [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
158
+ ))),
159
+ ]))
160
+
161
+ def forward(self, x):
162
+ x = self.root(x)
163
+ x = self.body(x)
164
+ return x
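
A small sketch (not part of the commit) wiring ResNetV2 up with the (3, 4, 9) / width_factor=1 layout that get_r50_b16_config() in models/utils.py describes; the output shape noted below is derived from the strides above rather than measured.

import torch
from models.resnet_v2 import ResNetV2

hybrid = ResNetV2(block_units=(3, 4, 9), width_factor=1)  # matches get_r50_b16_config()
with torch.no_grad():
    feats = hybrid(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected torch.Size([1, 1024, 14, 14]), i.e. a 14x14 grid for the hybrid ViT
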
models/utils.py ADDED
@@ -0,0 +1,238 @@
1
+ import math
2
+ import torch
3
+ import warnings
4
+ import ml_collections
5
+ import random
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def DiffAugment(x, types=[], prob = 0.5, detach=True):
10
+ """
11
+ x.shape = B, C, H, W
12
+ """
13
+ if random.random() < prob:
14
+ with torch.set_grad_enabled(not detach):
15
+ x = random_hflip(x, prob=0.5)
16
+ for p in types:
17
+ for f in AUGMENT_FNS[p]:
18
+ x = f(x)
19
+ x = x.contiguous()
20
+ return x
21
+
22
+
23
+ def random_hflip(tensor, prob):
24
+ if prob > random.random():
25
+ return tensor
26
+ return torch.flip(tensor, dims=(3,))
27
+
28
+ def rand_brightness(x):
29
+ x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
30
+ return x
31
+
32
+ def rand_saturation(x):
33
+ x_mean = x.mean(dim=1, keepdim=True)
34
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
35
+ return x
36
+
37
+ def rand_contrast(x):
38
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
39
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
40
+ return x
41
+
42
+ def rand_translation(x, ratio=0.125):
43
+ shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
44
+ translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
45
+ translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
46
+ grid_batch, grid_x, grid_y = torch.meshgrid(
47
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
48
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
49
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
50
+ )
51
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
52
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
53
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
54
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
55
+ return x
56
+
57
+ def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
58
+ w, h = x.size(2), x.size(3)
59
+
60
+ imgs = []
61
+ for img in x.unbind(dim = 0):
62
+ max_h = int(w * ratio * ratio_h)
63
+ max_v = int(h * ratio * ratio_v)
64
+
65
+ value_h = random.randint(0, max_h) * 2 - max_h
66
+ value_v = random.randint(0, max_v) * 2 - max_v
67
+
68
+ if abs(value_h) > 0:
69
+ img = torch.roll(img, value_h, 2)
70
+
71
+ if abs(value_v) > 0:
72
+ img = torch.roll(img, value_v, 1)
73
+
74
+ imgs.append(img)
75
+
76
+ return torch.stack(imgs)
77
+
78
+ def rand_offset_h(x, ratio=1):
79
+ return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)
80
+
81
+ def rand_offset_v(x, ratio=1):
82
+ return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)
83
+
84
+ def rand_cutout(x, ratio=0.5):
85
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
86
+ offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
87
+ offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
88
+ grid_batch, grid_x, grid_y = torch.meshgrid(
89
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
90
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
91
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
92
+ )
93
+ grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
94
+ grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
95
+ mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
96
+ mask[grid_batch, grid_x, grid_y] = 0
97
+ x = x * mask.unsqueeze(1)
98
+ return x
99
+
100
+
101
+ AUGMENT_FNS = {
102
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
103
+ 'offset': [rand_offset],
104
+ 'offset_h': [rand_offset_h],
105
+ 'offset_v': [rand_offset_v],
106
+ 'translation': [rand_translation],
107
+ 'cutout': [rand_cutout],
108
+ }
109
+
110
+
111
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
112
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
113
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
114
+ def norm_cdf(x):
115
+ # Computes standard normal cumulative distribution function
116
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
117
+
118
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
119
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
120
+ "The distribution of values may be incorrect.",
121
+ stacklevel=2)
122
+
123
+ with torch.no_grad():
124
+ # Values are generated by using a truncated uniform distribution and
125
+ # then using the inverse CDF for the normal distribution.
126
+ # Get upper and lower cdf values
127
+ l = norm_cdf((a - mean) / std)
128
+ u = norm_cdf((b - mean) / std)
129
+
130
+ # Uniformly fill tensor with values from [l, u], then translate to
131
+ # [2l-1, 2u-1].
132
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
133
+
134
+ # Use inverse cdf transform for normal distribution to get truncated
135
+ # standard normal
136
+ tensor.erfinv_()
137
+
138
+ # Transform to proper mean, std
139
+ tensor.mul_(std * math.sqrt(2.))
140
+ tensor.add_(mean)
141
+
142
+ # Clamp to ensure it's in the proper range
143
+ tensor.clamp_(min=a, max=b)
144
+ return tensor
145
+
146
+
147
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
148
+ # type: (Tensor, float, float, float, float) -> Tensor
149
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
150
+
151
+
152
+ def get_testing():
153
+ """Returns a minimal configuration for testing."""
154
+ config = ml_collections.ConfigDict()
155
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
156
+ config.hidden_size = 1
157
+ config.transformer = ml_collections.ConfigDict()
158
+ config.transformer.mlp_dim = 1
159
+ config.transformer.num_heads = 1
160
+ config.transformer.num_layers = 1
161
+ config.transformer.attention_dropout_rate = 0.0
162
+ config.transformer.dropout_rate = 0.1
163
+ config.classifier = 'token'
164
+ config.representation_size = None
165
+ return config
166
+
167
+
168
+ def get_b16_config():
169
+ """Returns the ViT-B/16 configuration."""
170
+ config = ml_collections.ConfigDict()
171
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
172
+ config.hidden_size = 768
173
+ config.transformer = ml_collections.ConfigDict()
174
+ config.transformer.mlp_dim = 3072
175
+ config.transformer.num_heads = 12
176
+ config.transformer.num_layers = 12
177
+ config.transformer.attention_dropout_rate = 0.0
178
+ config.transformer.dropout_rate = 0.1
179
+ config.classifier = 'token'
180
+ config.representation_size = None
181
+ return config
182
+
183
+
184
+ def get_r50_b16_config():
185
+ """Returns the Resnet50 + ViT-B/16 configuration."""
186
+ config = get_b16_config()
187
+ del config.patches.size
188
+ config.patches.grid = (14, 14)
189
+ config.resnet = ml_collections.ConfigDict()
190
+ config.resnet.num_layers = (3, 4, 9)
191
+ config.resnet.width_factor = 1
192
+ return config
193
+
194
+
195
+ def get_b32_config():
196
+ """Returns the ViT-B/32 configuration."""
197
+ config = get_b16_config()
198
+ config.patches.size = (32, 32)
199
+ return config
200
+
201
+
202
+ def get_l16_config():
203
+ """Returns the ViT-L/16 configuration."""
204
+ config = ml_collections.ConfigDict()
205
+ config.patches = ml_collections.ConfigDict({'size': (16, 16)})
206
+ config.hidden_size = 1024
207
+ config.transformer = ml_collections.ConfigDict()
208
+ config.transformer.mlp_dim = 4096
209
+ config.transformer.num_heads = 16
210
+ config.transformer.num_layers = 24
211
+ config.transformer.attention_dropout_rate = 0.0
212
+ config.transformer.dropout_rate = 0.1
213
+ config.classifier = 'token'
214
+ config.representation_size = None
215
+ return config
216
+
217
+
218
+ def get_l32_config():
219
+ """Returns the ViT-L/32 configuration."""
220
+ config = get_l16_config()
221
+ config.patches.size = (32, 32)
222
+ return config
223
+
224
+
225
+ def get_h14_config():
226
+ """Returns the ViT-L/16 configuration."""
227
+ config = ml_collections.ConfigDict()
228
+ config.patches = ml_collections.ConfigDict({'size': (14, 14)})
229
+ config.hidden_size = 1280
230
+ config.transformer = ml_collections.ConfigDict()
231
+ config.transformer.mlp_dim = 5120
232
+ config.transformer.num_heads = 16
233
+ config.transformer.num_layers = 32
234
+ config.transformer.attention_dropout_rate = 0.0
235
+ config.transformer.dropout_rate = 0.1
236
+ config.classifier = 'token'
237
+ config.representation_size = None
238
+ return config
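
A short sketch (not part of the commit) of the two helpers in this file that the deployment code leans on most, DiffAugment and trunc_normal_; the tensor sizes are illustrative.

import torch
from models.utils import DiffAugment, trunc_normal_

x = torch.randn(8, 3, 224, 224)
aug = DiffAugment(x, types=['color', 'translation'], prob=0.9, detach=True)  # same shape as x

ada_tokens = torch.zeros(1, 4, 384)   # e.g. 4 adapter tokens for a 384-d ViT-small
trunc_normal_(ada_tokens, std=0.02)   # in-place truncated-normal init, as ProtoNet_AdaTok does
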
models/vision_transformer.py ADDED
@@ -0,0 +1,246 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+ from functools import partial
6
+ from .utils import trunc_normal_
7
+
8
+
9
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
10
+ if drop_prob == 0. or not training:
11
+ return x
12
+ keep_prob = 1 - drop_prob
13
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
14
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
15
+ random_tensor.floor_() # binarize
16
+ output = x.div(keep_prob) * random_tensor
17
+ return output
18
+
19
+
20
+ class DropPath(nn.Module):
21
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
22
+ """
23
+ def __init__(self, drop_prob=None):
24
+ super(DropPath, self).__init__()
25
+ self.drop_prob = drop_prob
26
+
27
+ def forward(self, x):
28
+ return drop_path(x, self.drop_prob, self.training)
29
+
30
+
31
+ class Mlp(nn.Module):
32
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
33
+ super().__init__()
34
+ out_features = out_features or in_features
35
+ hidden_features = hidden_features or in_features
36
+ self.fc1 = nn.Linear(in_features, hidden_features)
37
+ self.act = act_layer()
38
+ self.fc2 = nn.Linear(hidden_features, out_features)
39
+ self.drop = nn.Dropout(drop)
40
+
41
+ def forward(self, x):
42
+ x = self.fc1(x)
43
+ x = self.act(x)
44
+ x = self.drop(x)
45
+ x = self.fc2(x)
46
+ x = self.drop(x)
47
+ return x
48
+
49
+
50
+ class Attention(nn.Module):
51
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
52
+ super().__init__()
53
+ self.num_heads = num_heads
54
+ head_dim = dim // num_heads
55
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn
+
+
+class Block(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, return_attention=False):
+        y, attn = self.attn(self.norm1(x))
+        if return_attention:
+            return attn
+        x = x + self.drop_path(y)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        num_patches = (img_size // patch_size) * (img_size // patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer """
+    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # Classifier head
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        class_pos_embed = self.pos_embed[:, 0]
+        patch_pos_embed = self.pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size
+        h0 = h // self.patch_embed.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+            align_corners=False,
+            recompute_scale_factor=False
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x, ada_token=None):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)  # patch linear embedding
+
+        # add the [CLS] token to the embed patch tokens
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        if ada_token is not None:
+            ada_tokens = ada_token.expand(B, -1, -1)  # B, p, d
+            x = torch.cat((x, ada_tokens), dim=1)
+
+        return self.pos_drop(x)
+
+    def forward(self, x, ada_token=None, use_patches=False):
+        x = self.prepare_tokens(x, ada_token)
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        if use_patches:
+            return x[:, 1:]
+        else:
+            return x[:, 0]
+
+    def get_last_selfattention(self, x):
+        x = self.prepare_tokens(x)
+        for i, blk in enumerate(self.blocks):
+            if i < len(self.blocks) - 1:
+                x = blk(x)
+            else:
+                # return attention of the last block
+                return blk(x, return_attention=True)
+
+    def get_intermediate_layers(self, x, n=1):
+        x = self.prepare_tokens(x)
+        # we return the output tokens from the `n` last blocks
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if len(self.blocks) - i <= n:
+                output.append(self.norm(x))
+        return output
+
+
+def vit_tiny(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_small(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_base(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
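For orientation, here is a minimal usage sketch of the factory functions above as frozen feature extractors. It assumes this file is importable as models.vit; the module path and the random tensor standing in for a preprocessed batch are illustrative, not part of the commit.

# Minimal usage sketch (assumptions: this file is importable as models.vit,
# and the random input stands in for a batch of preprocessed 224x224 crops).
import torch
from models.vit import vit_small

backbone = vit_small(patch_size=16)              # 384-dim DINO-style ViT-S/16, no classifier head
backbone.eval()

x = torch.randn(2, 3, 224, 224)                  # (B, C, H, W)
with torch.no_grad():
    cls_feat = backbone(x)                       # (2, 384): [CLS] token features (use_patches=False)
    patch_feat = backbone(x, use_patches=True)   # (2, 196, 384): one token per 16x16 patch
    attn = backbone.get_last_selfattention(x)    # (2, 6, 197, 197): last-block attention maps
print(cls_feat.shape, patch_feat.shape, attn.shape)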
models/vit_google.py ADDED
@@ -0,0 +1,372 @@
+import copy
+import logging
+import math
+
+from os.path import join as pjoin
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
+from torch.nn.modules.utils import _pair
+from scipy import ndimage
+
+from .utils import get_b16_config
+from .resnet_v2 import ResNetV2
+
+
+CONFIGS = {
+    'ViT-B_16': get_b16_config(),
+    #'ViT-B_32': get_b32_config(),
+    #'ViT-L_16': get_l16_config(),
+    #'ViT-L_32': get_l32_config(),
+    #'ViT-H_14': get_h14_config(),
+    #'R50-ViT-B_16': get_r50_b16_config(),
+    #'testing': configs.get_testing(),
+}
+
+ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
+ATTENTION_K = "MultiHeadDotProductAttention_1/key"
+ATTENTION_V = "MultiHeadDotProductAttention_1/value"
+ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
+FC_0 = "MlpBlock_3/Dense_0"
+FC_1 = "MlpBlock_3/Dense_1"
+ATTENTION_NORM = "LayerNorm_0"
+MLP_NORM = "LayerNorm_2"
+
+
+def np2th(weights, conv=False):
+    """Possibly convert HWIO to OIHW."""
+    if conv:
+        weights = weights.transpose([3, 2, 0, 1])
+    return torch.from_numpy(weights)
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
+class Attention(nn.Module):
+    def __init__(self, config, vis):
+        super(Attention, self).__init__()
+        self.vis = vis
+        self.num_attention_heads = config.transformer["num_heads"]
+        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = Linear(config.hidden_size, self.all_head_size)
+        self.key = Linear(config.hidden_size, self.all_head_size)
+        self.value = Linear(config.hidden_size, self.all_head_size)
+
+        self.out = Linear(config.hidden_size, config.hidden_size)
+        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
+        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
+
+        self.softmax = Softmax(dim=-1)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        attention_probs = self.softmax(attention_scores)
+        weights = attention_probs if self.vis else None
+        attention_probs = self.attn_dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        attention_output = self.out(context_layer)
+        attention_output = self.proj_dropout(attention_output)
+        return attention_output, weights
+
+
+class Mlp(nn.Module):
+    def __init__(self, config):
+        super(Mlp, self).__init__()
+        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
+        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
+        self.act_fn = ACT2FN["gelu"]
+        self.dropout = Dropout(config.transformer["dropout_rate"])
+
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.xavier_uniform_(self.fc1.weight)
+        nn.init.xavier_uniform_(self.fc2.weight)
+        nn.init.normal_(self.fc1.bias, std=1e-6)
+        nn.init.normal_(self.fc2.bias, std=1e-6)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        x = self.dropout(x)
+        return x
+
+
+class Embeddings(nn.Module):
+    """Construct the embeddings from patch, position embeddings.
+    """
+    def __init__(self, config, img_size, in_channels=3):
+        super(Embeddings, self).__init__()
+        self.hybrid = None
+        img_size = _pair(img_size)
+
+        if config.patches.get("grid") is not None:
+            grid_size = config.patches["grid"]
+            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
+            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
+            self.hybrid = True
+        else:
+            patch_size = _pair(config.patches["size"])
+            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
+            self.hybrid = False
+
+        if self.hybrid:
+            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
+                                         width_factor=config.resnet.width_factor)
+            in_channels = self.hybrid_model.width * 16
+        self.patch_size = patch_size
+        self.patch_embeddings = Conv2d(in_channels=in_channels,
+                                       out_channels=config.hidden_size,
+                                       kernel_size=patch_size,
+                                       stride=patch_size)
+        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+
+        self.dropout = Dropout(config.transformer["dropout_rate"])
+
+    def interpolate_pos_encoding(self, x, h, w):
+        npatch = x.shape[1] - 1
+        N = self.position_embeddings.shape[1] - 1
+        if npatch == N and w == h:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size[0]
+        h0 = h // self.patch_size[1]
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
+            mode='bicubic',
+            align_corners=False,
+            recompute_scale_factor=False
+        )
+        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(self, x):
+        B, nc, h, w = x.shape
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+
+        if self.hybrid:
+            x = self.hybrid_model(x)
+
+        # Linear embedding
+        x = self.patch_embeddings(x)
+
+        # add the [CLS] token to the embed patch tokens
+        x = x.flatten(2)
+        x = x.transpose(-1, -2)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # add positional encoding to each token
+        embeddings = x + self.interpolate_pos_encoding(x, h, w)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class Block(nn.Module):
+    def __init__(self, config, vis):
+        super(Block, self).__init__()
+        self.hidden_size = config.hidden_size
+        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
+        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
+        self.ffn = Mlp(config)
+        self.attn = Attention(config, vis)
+
+    def forward(self, x):
+        h = x
+        x = self.attention_norm(x)
+        x, weights = self.attn(x)
+        x = x + h
+
+        h = x
+        x = self.ffn_norm(x)
+        x = self.ffn(x)
+        x = x + h
+        return x, weights
+
+    def load_from(self, weights, n_block):
+        ROOT = f"Transformer/encoderblock_{n_block}"
+        with torch.no_grad():
+            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
+
+            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
+            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
+            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
+            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
+
+            self.attn.query.weight.copy_(query_weight)
+            self.attn.key.weight.copy_(key_weight)
+            self.attn.value.weight.copy_(value_weight)
+            self.attn.out.weight.copy_(out_weight)
+            self.attn.query.bias.copy_(query_bias)
+            self.attn.key.bias.copy_(key_bias)
+            self.attn.value.bias.copy_(value_bias)
+            self.attn.out.bias.copy_(out_bias)
+
+            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
+            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
+            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
+            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
+
+            self.ffn.fc1.weight.copy_(mlp_weight_0)
+            self.ffn.fc2.weight.copy_(mlp_weight_1)
+            self.ffn.fc1.bias.copy_(mlp_bias_0)
+            self.ffn.fc2.bias.copy_(mlp_bias_1)
+
+            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
+            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
+            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
+            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
+
+
+class Encoder(nn.Module):
+    def __init__(self, config, vis):
+        super(Encoder, self).__init__()
+        self.vis = vis
+        self.layer = nn.ModuleList()
+        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
+        for _ in range(config.transformer["num_layers"]):
+            layer = Block(config, vis)
+            self.layer.append(copy.deepcopy(layer))
+
+    def forward(self, hidden_states):
+        attn_weights = []
+        for layer_block in self.layer:
+            hidden_states, weights = layer_block(hidden_states)
+            if self.vis:
+                attn_weights.append(weights)
+        encoded = self.encoder_norm(hidden_states)
+        return encoded, attn_weights
+
+
+class Transformer(nn.Module):
+    def __init__(self, config, img_size, vis):
+        super(Transformer, self).__init__()
+        self.embeddings = Embeddings(config, img_size=img_size)
+        self.encoder = Encoder(config, vis)
+
+    def forward(self, input_ids):
+        embedding_output = self.embeddings(input_ids)
+        encoded, attn_weights = self.encoder(embedding_output)
+        return encoded, attn_weights
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, config, img_size=224, vis=False):
+        super(VisionTransformer, self).__init__()
+        #self.num_classes = num_classes
+        #self.classifier = config.classifier
+        self.embed_dim = config.hidden_size
+
+        self.transformer = Transformer(config, img_size, vis)
+        #self.head = Linear(config.hidden_size, num_classes)
+
+    def forward(self, x, labels=None, use_patches=False):
+        x, attn_weights = self.transformer(x)
+        #logits = self.head(x[:, 0])
+
+        #if labels is not None:
+        #    loss_fct = CrossEntropyLoss()
+        #    loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
+        #    return loss
+        #else:
+        #    return logits, attn_weights
+
+        if use_patches:
+            return x[:, 1:]
+        else:
+            return x[:, 0]
+
+    def load_from(self, weights):
+        with torch.no_grad():
+            #if self.zero_head:
+            #    nn.init.zeros_(self.head.weight)
+            #    nn.init.zeros_(self.head.bias)
+            #else:
+            #    self.head.weight.copy_(np2th(weights["head/kernel"]).t())
+            #    self.head.bias.copy_(np2th(weights["head/bias"]).t())
+
+            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
+            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
+            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
+            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
+            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
+
+            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
+            posemb_new = self.transformer.embeddings.position_embeddings
+            if posemb.size() == posemb_new.size():
+                self.transformer.embeddings.position_embeddings.copy_(posemb)
+            else:
+                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
+                ntok_new = posemb_new.size(1)
+
+                # NOTE: self.classifier is commented out in __init__, so fall back to "token";
+                # the Embeddings module above always prepends a [CLS] token.
+                if getattr(self, "classifier", "token") == "token":
+                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
+                    ntok_new -= 1
+                else:
+                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+
+                gs_old = int(np.sqrt(len(posemb_grid)))
+                gs_new = int(np.sqrt(ntok_new))
+                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
+                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
+
+                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
+                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
+                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
+                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
+                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
+
+            for bname, block in self.transformer.encoder.named_children():
+                for uname, unit in block.named_children():
+                    unit.load_from(weights, n_block=uname)
+
+            if self.transformer.embeddings.hybrid:
+                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
+                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
+                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
+                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
+                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
+
+                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
+                    for uname, unit in block.named_children():
+                        unit.load_from(weights, n_block=bname, n_unit=uname)
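A corresponding sketch for the Google-style ViT above, assuming this file is importable as models.vit_google and that a jax-format ViT-B/16 .npz checkpoint is available locally; the checkpoint path is hypothetical and not part of the commit.

# Minimal usage sketch (assumptions: importable as models.vit_google;
# 'checkpoints/ViT-B_16.npz' is an illustrative path to the official jax release).
import numpy as np
import torch
from models.vit_google import CONFIGS, VisionTransformer

config = CONFIGS['ViT-B_16']
model = VisionTransformer(config, img_size=224, vis=False)

weights = np.load('checkpoints/ViT-B_16.npz')   # .npz dict keyed like "embedding/kernel"
model.load_from(weights)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    cls_feat = model(x)                   # (1, 768): [CLS] embedding
    patches = model(x, use_patches=True)  # (1, 196, 768): patch embeddings
print(cls_feat.shape, patches.shape)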
requirements.txt ADDED
@@ -0,0 +1,14 @@
+torch==1.13.1
+torchvision
+numpy<2.0
+tqdm
+matplotlib
+dotmap
+Pillow
+timm
+ml-collections
+ftfy
+tensorboard
+Google-Images-Search
+semantic-version
+pytz
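torch is pinned to 1.13.1 and numpy is capped below 2.0, presumably to keep the checkpoint-loading code compatible; the remaining entries are left unpinned. After pip install -r requirements.txt, an optional sanity check such as the sketch below (illustrative only, not part of the commit) can confirm the key packages resolved:

# Optional environment check; package names mirror requirements.txt.
from importlib.metadata import version, PackageNotFoundError

for pkg in ["torch", "torchvision", "numpy", "dotmap", "Pillow", "timm", "ml-collections"]:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")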