JH-C-k committed · verified
Commit 2642b57 · 1 Parent(s): 9acab03

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ vitl14_attention.png filter=lfs diff=lfs merge=lfs -text
+ vitl14_patchnorms.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,163 @@
+ ---
+ library_name: transformers
+ license: mit
+ pipeline_tag: image-feature-extraction
+ tags:
+ - clip
+ ---
+
+ # OpenCLIP ViT-L/14 with Test-Time Register
+
+ Register tokens in ViTs were introduced as learnable tokens in [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588) to mitigate artifacts in intermediate feature maps.
+ In [Vision Transformers Don't Need *Trained* Registers](https://arxiv.org/abs/2506.08010), we introduced a training-free method to create registers. These *test-time registers* serve a similar purpose
+ as the original trained registers, but can be added post hoc to any ViT to mitigate artifacts, enhance model interpretability, and modestly improve downstream performance on tasks such as segmentation and depth estimation.
+
+ ## Model description
+
+ The base model is [OpenCLIP-ViT-L-14-laion2B-s32B-b82K](https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K). With test-time registers, the model's internal representations
+ are cleaner (see below). Using the environment from [here](https://github.com/nickjiang2378/test-time-registers/blob/main/environment.yml) and evaluating in bfloat16 gives an ImageNet-1k zero-shot accuracy of 76.4 for both the original model and the variant with test-time registers.
+ This model is intended to be used with this [repo](https://github.com/nickjiang2378/test-time-registers) and transformers==4.45.1. The model can also be used for fine-tuning or other downstream tasks.
+
+ <img src="https://huggingface.co/amildravid4292/clip-vitl14-test-time-registers/resolve/main/vitl14_attention.png" alt="attention maps" width="600"/>
+ <img src="https://huggingface.co/amildravid4292/clip-vitl14-test-time-registers/resolve/main/vitl14_patchnorms.png" alt="patch norms" width="600"/>
+
+ ## Quick Start
+
+ ```python
+ from transformers import AutoModel
+ from PIL import Image
+ import torch
+
+ # Load the complete model with all components
+ model = AutoModel.from_pretrained(
+     "amildravid4292/clip-vitl14-test-time-registers",
+     trust_remote_code=True
+ )
+
+ # Check what was loaded
+ print(f"Register tokens: {model.num_register_tokens}")
+ print(f"Neuron dict: {model.neuron_dict}")
+ print(f"Tokenizer available: {model.tokenizer is not None}")
+ print(f"Preprocessor available: {model.preprocessor is not None}")
+ print(f"Zero-shot classifier available: {model.zeroshot_classifier is not None}")
+ ```
+
+ ## Usage Examples
+
+ ### Image Processing
+ ```python
+ from PIL import Image
+
+ # Load and preprocess image
+ image = Image.open("your_image.jpg")
+ image_tensor = model.preprocess_image(image).unsqueeze(0)
+
+ image_features = model.encode_image(image_tensor)
+
+ # to run inference with the original model without test-time registers
+ image_features = model.encode_image(
+     image_tensor,
+     neuron_dict=None,
+     num_register_tokens=0
+ )
+ ```
+
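To check what the test-time registers change at the output level, the two calls above can be compared directly. A minimal sanity-check sketch (reusing `model` and `image_tensor` from the block above; per the model description, the global embedding is expected to stay close, while the cleanup shows up in intermediate patch norms and attention maps):

```python
import torch.nn.functional as F

with torch.no_grad():
    feats_with_registers = model.encode_image(image_tensor)
    feats_baseline = model.encode_image(
        image_tensor,
        neuron_dict=None,
        num_register_tokens=0,
    )

cos = F.cosine_similarity(feats_with_registers, feats_baseline, dim=-1)
print(f"Cosine similarity with vs. without test-time registers: {cos.item():.4f}")
```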
+ ### Text Processing
+ ```python
+ # Tokenize text
+ text = ["a photo of a cat", "a photo of a dog"]
+ text_tokens = model.tokenize(text)
+
+ # Encode text
+ text_features = model.encode_text(text_tokens)
+ ```
+
+ ### Complete Pipeline
+ ```python
+ import torch
+ from torchvision.datasets import ImageNet
+ from tqdm import tqdm
+ from transformers import AutoModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # load model
+ model = AutoModel.from_pretrained('amildravid4292/clip-vitl14-test-time-registers', trust_remote_code=True)
+ model = model.to(device).bfloat16()
+ classifier = model.zeroshot_classifier.to(device).bfloat16()
+
+ # load data
+ imagenet_dataset = ImageNet(root='/datasets/ilsvrc/current', split='val', transform=model.preprocessor)
+ ground_truth_labels = [imagenet_dataset.targets[i] for i in range(len(imagenet_dataset))]
+ loader = torch.utils.data.DataLoader(imagenet_dataset, batch_size=100, num_workers=4, pin_memory=True, shuffle=False)
+
+ # run zero-shot classification
+ with torch.no_grad():
+     correct = [0, 0]
+     for i, (images, target) in enumerate(tqdm(loader)):
+         images = images.to(device).bfloat16()
+         target = target.to(device)  # keep integer labels for exact comparison
+
+         # predict
+         image_features = model.encode_image(images)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         logits = 100. * image_features @ classifier
+
+         pred = logits.argmax(dim=-1)
+         correct[0] += (pred == target).sum().item()
+         correct[1] += target.size(0)
+
+ print(correct[0] / correct[1])
+ ```
+
+ ## Advanced Usage
+
+ ### Custom Neuron Modifications
+ ```python
+ # Override the saved neuron configuration
+ custom_neuron_dict = {0: [10, 20, 30]}  # Modify neurons 10,20,30 in layer 0
+
+ image_features = model.encode_image(
+     image_tensor,
+     num_register_tokens=4,
+     neuron_dict=custom_neuron_dict
+ )
+ ```
+
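The repository also ships a saved neuron configuration, `ViT-L-14-336_register.json`, in the same layer-to-neuron-indices format. A hedged sketch of loading it and passing it to `encode_image` (the string-to-int key conversion and `num_register_tokens=1` are assumptions, not documented behavior):

```python
import json

from huggingface_hub import hf_hub_download

# Download the neuron configuration shipped with this repo.
config_path = hf_hub_download(
    repo_id="amildravid4292/clip-vitl14-test-time-registers",
    filename="ViT-L-14-336_register.json",
)

# JSON keys are strings ("9", "8", ...); the in-code examples use int layer keys.
with open(config_path) as f:
    loaded_neuron_dict = {int(layer): neurons for layer, neurons in json.load(f).items()}

image_features = model.encode_image(
    image_tensor,
    num_register_tokens=1,  # assumed; match the register count you calibrated for
    neuron_dict=loaded_neuron_dict,
)
```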
+ ### Different Register Token Counts
+ ```python
+ # Use different number of register tokens
+ image_features = model.encode_image(
+     image_tensor,
+     num_register_tokens=8  # Override the default
+ )
+ ```
+
+ ## Model Details
+
+ - **Base Architecture**: ViT-L/14
+ - **Training Data**: LAION-2B subset
+
+ ### BibTeX entry and citation info
+
+ ```bibtex
+ @misc{jiang2025visiontransformersdontneed,
+       title={Vision Transformers Don't Need Trained Registers},
+       author={Nick Jiang and Amil Dravid and Alexei Efros and Yossi Gandelsman},
+       year={2025},
+       eprint={2506.08010},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2506.08010},
+ }
+ ```
ViT-L-14-336_register.json ADDED
@@ -0,0 +1,116 @@
1
+ {
2
+ "9": [
3
+ 815,
4
+ 4078,
5
+ 3618,
6
+ 2693,
7
+ 3973,
8
+ 1744,
9
+ 1983,
10
+ 1157,
11
+ 1309,
12
+ 1335,
13
+ 2607,
14
+ 2396,
15
+ 3049,
16
+ 1610,
17
+ 2621,
18
+ 2867,
19
+ 2012,
20
+ 1924,
21
+ 2394,
22
+ 3097,
23
+ 3125,
24
+ 3959,
25
+ 3210,
26
+ 2855,
27
+ 3609,
28
+ 526,
29
+ 3362,
30
+ 3395,
31
+ 2626,
32
+ 503,
33
+ 2941,
34
+ 3696,
35
+ 1823,
36
+ 2000,
37
+ 129,
38
+ 3667,
39
+ 1372,
40
+ 147,
41
+ 1150,
42
+ 852,
43
+ 3222
44
+ ],
45
+ "8": [
46
+ 745,
47
+ 3249,
48
+ 2585,
49
+ 1537,
50
+ 200,
51
+ 1603,
52
+ 1851,
53
+ 3523,
54
+ 3697,
55
+ 3137,
56
+ 2563,
57
+ 2293,
58
+ 730,
59
+ 906,
60
+ 1528,
61
+ 3348,
62
+ 2438,
63
+ 1564,
64
+ 1540,
65
+ 3238,
66
+ 3606
67
+ ],
68
+ "10": [
69
+ 357,
70
+ 1654,
71
+ 3940,
72
+ 2319,
73
+ 2560,
74
+ 2559,
75
+ 4009,
76
+ 3029,
77
+ 951,
78
+ 1903,
79
+ 738,
80
+ 1602,
81
+ 1807,
82
+ 2018,
83
+ 1281,
84
+ 267,
85
+ 3539,
86
+ 1015,
87
+ 496,
88
+ 693,
89
+ 2278,
90
+ 7,
91
+ 856,
92
+ 2785,
93
+ 2690,
94
+ 1367
95
+ ],
96
+ "7": [
97
+ 3228,
98
+ 2550,
99
+ 2977,
100
+ 3716,
101
+ 2467
102
+ ],
103
+ "0": [
104
+ 2890,
105
+ 1779,
106
+ 3761
107
+ ],
108
+ "6": [
109
+ 1042,
110
+ 2315,
111
+ 1674
112
+ ],
113
+ "3": [
114
+ 410
115
+ ]
116
+ }
__init__.py ADDED
@@ -0,0 +1,8 @@
+ from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+ from factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
+ from factory import list_models, add_model_config, get_model_config, load_checkpoint
+ from pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
+     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+ from tokenizer import SimpleTokenizer, tokenize, decode
+ from transform import image_transform, AugmentationCfg
+ from openai_templates import OPENAI_IMAGENET_TEMPLATES
__pycache__/imagenet_classes.cpython-310.pyc ADDED
Binary file (21.7 kB).

__pycache__/misc.cpython-310.pyc ADDED
Binary file (4.05 kB).

__pycache__/model.cpython-310.pyc ADDED
Binary file (12.1 kB).

__pycache__/modified_resnet.cpython-310.pyc ADDED
Binary file (6.39 kB).

__pycache__/shared.cpython-310.pyc ADDED
Binary file (15.9 kB).

__pycache__/timm_model.cpython-310.pyc ADDED
Binary file (4.04 kB).

__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (8.63 kB).

__pycache__/transformer.cpython-310.pyc ADDED
Binary file (23.1 kB).
config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "model_type": "custom_clip_with_registers",
+
+   "processor_class": "CLIPProcessor",
+   "tokenizer_class": "CLIPTokenizerFast",
+
+   "architectures": ["CustomCLIPModel"],
+   "auto_map": {
+     "AutoConfig": "modeling_custom_clip.CustomCLIPConfig",
+     "AutoModel": "modeling_custom_clip.CustomCLIPModel"
+   },
+   "vision_config": {
+     "hidden_size": 1024,
+     "num_hidden_layers": 24,
+     "num_attention_heads": 16,
+     "image_size": 336,
+     "patch_size": 14
+   },
+   "text_config": {
+     "vocab_size": 49408,
+     "hidden_size": 768,
+     "num_hidden_layers": 12,
+     "max_position_embeddings": 77
+   },
+   "neuron_dict": {},
+   "projection_dim": 768,
+   "torch_dtype": "float32",
+   "transformers_version": "4.21.0"
+ }
config_TTR_bak.json ADDED
@@ -0,0 +1,145 @@
1
+ {
2
+ "model_type": "custom_clip_with_registers",
3
+
4
+ "processor_class": "CLIPProcessor",
5
+ "tokenizer_class": "CLIPTokenizerFast",
6
+
7
+ "architectures": ["CustomCLIPModel"],
8
+ "auto_map": {
9
+ "AutoConfig": "modeling_custom_clip.CustomCLIPConfig",
10
+ "AutoModel": "modeling_custom_clip.CustomCLIPModel"
11
+ },
12
+ "vision_config": {
13
+ "hidden_size": 1024,
14
+ "num_hidden_layers": 24,
15
+ "num_attention_heads": 16,
16
+ "image_size": 336,
17
+ "patch_size": 14
18
+ },
19
+ "text_config": {
20
+ "vocab_size": 49408,
21
+ "hidden_size": 768,
22
+ "num_hidden_layers": 12,
23
+ "max_position_embeddings": 77
24
+ },
25
+ "num_register_tokens": 1,
26
+ "neuron_dict": {
27
+ "9": [
28
+ 815,
29
+ 4078,
30
+ 3618,
31
+ 2693,
32
+ 3973,
33
+ 1744,
34
+ 1983,
35
+ 1157,
36
+ 1309,
37
+ 1335,
38
+ 2607,
39
+ 2396,
40
+ 3049,
41
+ 1610,
42
+ 2621,
43
+ 2867,
44
+ 2012,
45
+ 1924,
46
+ 2394,
47
+ 3097,
48
+ 3125,
49
+ 3959,
50
+ 3210,
51
+ 2855,
52
+ 3609,
53
+ 526,
54
+ 3362,
55
+ 3395,
56
+ 2626,
57
+ 503,
58
+ 2941,
59
+ 3696,
60
+ 1823,
61
+ 2000,
62
+ 129,
63
+ 3667,
64
+ 1372,
65
+ 147,
66
+ 1150,
67
+ 852,
68
+ 3222
69
+ ],
70
+ "8": [
71
+ 745,
72
+ 3249,
73
+ 2585,
74
+ 1537,
75
+ 200,
76
+ 1603,
77
+ 1851,
78
+ 3523,
79
+ 3697,
80
+ 3137,
81
+ 2563,
82
+ 2293,
83
+ 730,
84
+ 906,
85
+ 1528,
86
+ 3348,
87
+ 2438,
88
+ 1564,
89
+ 1540,
90
+ 3238,
91
+ 3606
92
+ ],
93
+ "10": [
94
+ 357,
95
+ 1654,
96
+ 3940,
97
+ 2319,
98
+ 2560,
99
+ 2559,
100
+ 4009,
101
+ 3029,
102
+ 951,
103
+ 1903,
104
+ 738,
105
+ 1602,
106
+ 1807,
107
+ 2018,
108
+ 1281,
109
+ 267,
110
+ 3539,
111
+ 1015,
112
+ 496,
113
+ 693,
114
+ 2278,
115
+ 7,
116
+ 856,
117
+ 2785,
118
+ 2690,
119
+ 1367
120
+ ],
121
+ "7": [
122
+ 3228,
123
+ 2550,
124
+ 2977,
125
+ 3716,
126
+ 2467
127
+ ],
128
+ "0": [
129
+ 2890,
130
+ 1779,
131
+ 3761
132
+ ],
133
+ "6": [
134
+ 1042,
135
+ 2315,
136
+ 1674
137
+ ],
138
+ "3": [
139
+ 410
140
+ ]
141
+ },
142
+ "projection_dim": 768,
143
+ "torch_dtype": "float32",
144
+ "transformers_version": "4.21.0"
145
+ }
config_bak.json ADDED
@@ -0,0 +1,110 @@
1
+ {
2
+ "model_type": "custom_clip_with_registers",
3
+ "architectures": ["CustomCLIPModel"],
4
+ "auto_map": {
5
+ "AutoConfig": "modeling_custom_clip.CustomCLIPConfig",
6
+ "AutoModel": "modeling_custom_clip.CustomCLIPModel"
7
+ },
8
+ "vision_config": {
9
+ "hidden_size": 1024,
10
+ "num_hidden_layers": 24,
11
+ "num_attention_heads": 16,
12
+ "image_size": 336,
13
+ "patch_size": 14
14
+ },
15
+ "text_config": {
16
+ "vocab_size": 49408,
17
+ "hidden_size": 768,
18
+ "num_hidden_layers": 12,
19
+ "max_position_embeddings": 77
20
+ },
21
+ "num_register_tokens": 1,
22
+ "neuron_dict": {"10": [2924,
23
+ 2520,
24
+ 2936,
25
+ 675,
26
+ 517,
27
+ 1610,
28
+ 88,
29
+ 1950,
30
+ 3098,
31
+ 4082,
32
+ 1237,
33
+ 857,
34
+ 3020,
35
+ 1321,
36
+ 1128,
37
+ 3561,
38
+ 4091,
39
+ 69,
40
+ 3378,
41
+ 2304,
42
+ 977,
43
+ 1762,
44
+ 3598,
45
+ 371,
46
+ 1097],
47
+ "9": [1253,
48
+ 3658,
49
+ 1827,
50
+ 2600,
51
+ 4000,
52
+ 711,
53
+ 2726,
54
+ 615,
55
+ 2654,
56
+ 831,
57
+ 1,
58
+ 1387,
59
+ 2178,
60
+ 1967,
61
+ 2413,
62
+ 901,
63
+ 481,
64
+ 1514,
65
+ 292,
66
+ 692,
67
+ 3094,
68
+ 3470,
69
+ 932,
70
+ 2129],
71
+ "8": [3189,
72
+ 1491,
73
+ 2159,
74
+ 1196,
75
+ 1913,
76
+ 1340,
77
+ 2515,
78
+ 2163,
79
+ 955,
80
+ 1496,
81
+ 1891,
82
+ 1410,
83
+ 3725,
84
+ 632,
85
+ 188,
86
+ 726,
87
+ 1592,
88
+ 1017,
89
+ 1267,
90
+ 995,
91
+ 3465,
92
+ 3510,
93
+ 1494,
94
+ 3467,
95
+ 1896,
96
+ 2779,
97
+ 2309,
98
+ 3389,
99
+ 3682,
100
+ 1968,
101
+ 2904],
102
+ "7": [2226, 2565],
103
+ "6": [1450, 1551, 1024],
104
+ "5": [151, 1282],
105
+ "4": [2207],
106
+ "3": [2298, 2841]},
107
+ "projection_dim": 768,
108
+ "torch_dtype": "float32",
109
+ "transformers_version": "4.21.0"
110
+ }
constants.py ADDED
@@ -0,0 +1,2 @@
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
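These are the standard CLIP preprocessing statistics consumed by `transform.image_transform`. For reference, a minimal torchvision sketch of the equivalent normalization (the 336 px resolution follows `vision_config.image_size` in `config.json`; the bundled `image_transform`/`model.preprocessor` remain the authoritative pipeline):

```python
from torchvision import transforms

from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

# Resize/crop to the model's input resolution, then normalize with the
# OpenAI CLIP statistics defined above.
preprocess = transforms.Compose([
    transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(336),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])
```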
factory.py ADDED
@@ -0,0 +1,390 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from model import CLIP, convert_to_custom_text_state_dict,\
14
+ resize_pos_embed, get_cast_dtype
15
+ from openai_models import load_openai_model
16
+ from pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
17
+ list_pretrained_tags_by_model, download_pretrained_from_hf
18
+ from transform import image_transform, AugmentationCfg
19
+ from tokenizer import HFTokenizer, tokenize
20
+
21
+
22
+ HF_HUB_PREFIX = 'hf-hub:'
23
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
24
+ _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
25
+
26
+
27
+ def _natural_key(string_):
28
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
29
+
30
+
31
+ def _rescan_model_configs():
32
+ global _MODEL_CONFIGS
33
+
34
+ config_ext = ('.json',)
35
+ config_files = []
36
+ for config_path in _MODEL_CONFIG_PATHS:
37
+ if config_path.is_file() and config_path.suffix in config_ext:
38
+ config_files.append(config_path)
39
+ elif config_path.is_dir():
40
+ for ext in config_ext:
41
+ config_files.extend(config_path.glob(f'*{ext}'))
42
+
43
+ for cf in config_files:
44
+ with open(cf, 'r') as f:
45
+ model_cfg = json.load(f)
46
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
47
+ _MODEL_CONFIGS[cf.stem] = model_cfg
48
+
49
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
50
+
51
+
52
+ _rescan_model_configs() # initial populate of model config registry
53
+
54
+
55
+ def list_models():
56
+ """ enumerate available model architectures based on config files """
57
+ return list(_MODEL_CONFIGS.keys())
58
+
59
+
60
+ def add_model_config(path):
61
+ """ add model config path or file and update registry """
62
+ if not isinstance(path, Path):
63
+ path = Path(path)
64
+ _MODEL_CONFIG_PATHS.append(path)
65
+ _rescan_model_configs()
66
+
67
+
68
+ def get_model_config(model_name):
69
+ if model_name in _MODEL_CONFIGS:
70
+ return deepcopy(_MODEL_CONFIGS[model_name])
71
+ else:
72
+ return None
73
+
74
+
75
+ def get_tokenizer(model_name):
76
+ if model_name.startswith(HF_HUB_PREFIX):
77
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
78
+ else:
79
+ config = get_model_config(model_name)
80
+ tokenizer = HFTokenizer(
81
+ config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
82
+ return tokenizer
83
+
84
+
85
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
86
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
87
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
88
+ state_dict = checkpoint['state_dict']
89
+ else:
90
+ state_dict = checkpoint
91
+ if next(iter(state_dict.items()))[0].startswith('module'):
92
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
93
+ return state_dict
94
+
95
+
96
+ def load_checkpoint(model, checkpoint_path, strict=False):
97
+ state_dict = load_state_dict(checkpoint_path)
98
+ # detect old format and make compatible with new format
99
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
100
+ state_dict = convert_to_custom_text_state_dict(state_dict)
101
+ resize_pos_embed(state_dict, model)
102
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
103
+
104
+ model.num_register_tokens=state_dict["num_register_tokens"]
105
+ model.neuron_dict=state_dict["neuron_dict"]
106
+ model.visual.num_register_tokens=state_dict["num_register_tokens"]
107
+ model.visual.neuron_dict=state_dict["neuron_dict"]
108
+
109
+ return incompatible_keys
110
+
111
+
112
+ def create_model(
113
+ model_name: str,
114
+ pretrained: Optional[str] = None,
115
+ precision: str = 'fp32',
116
+ device: Union[str, torch.device] = 'cpu',
117
+ jit: bool = False,
118
+ force_quick_gelu: bool = False,
119
+ force_custom_text: bool = False,
120
+ force_patch_dropout: Optional[float] = None,
121
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
122
+ pretrained_image: bool = False,
123
+ pretrained_hf: bool = True,
124
+ cache_dir: Optional[str] = None,
125
+ output_dict: Optional[bool] = None,
126
+ require_pretrained: bool = False,
127
+ ):
128
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
129
+ if has_hf_hub_prefix:
130
+ model_id = model_name[len(HF_HUB_PREFIX):]
131
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
132
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
133
+
134
+ with open(config_path, 'r', encoding='utf-8') as f:
135
+ config = json.load(f)
136
+ pretrained_cfg = config['preprocess_cfg']
137
+ model_cfg = config['model_cfg']
138
+ else:
139
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
140
+ checkpoint_path = None
141
+ pretrained_cfg = {}
142
+ model_cfg = None
143
+
144
+
145
+
146
+ if isinstance(device, str):
147
+ device = torch.device(device)
148
+
149
+ if pretrained and pretrained.lower() == 'openai':
150
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
151
+ model = load_openai_model(
152
+ model_name,
153
+ precision=precision,
154
+ device=device,
155
+ cache_dir=cache_dir,
156
+ quick_gelu=force_quick_gelu,
157
+ )
158
+ else:
159
+ model_cfg = model_cfg or get_model_config(model_name)
160
+
161
+ if model_cfg is not None:
162
+ logging.info(f'Loaded {model_name} model config.')
163
+ else:
164
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
165
+ raise RuntimeError(f'Model config for {model_name} not found.')
166
+
167
+ if force_patch_dropout is not None:
168
+ # override the default patch dropout value
169
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
170
+
171
+ if force_image_size is not None:
172
+ # override model config's image size
173
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
174
+
175
+ is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
176
+ if pretrained_image:
177
+ if is_timm_model:
178
+ # pretrained weight loading for timm models set via vision_cfg
179
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
180
+ else:
181
+ assert False, 'pretrained image towers currently only supported for timm models'
182
+
183
+ # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
184
+ cast_dtype = get_cast_dtype(precision)
185
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
186
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
187
+
188
+ if custom_text:
189
+ if is_hf_model:
190
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
191
+ if "coca" in model_name:
192
+ raise ValueError('Coca is not implemented')
193
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
194
+ else:
195
+ raise ValueError('CustomTextCLIP is not implemented')
196
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
197
+ else:
198
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
199
+
200
+ if precision in ("fp16", "bf16"):
201
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
202
+ # manual mixed precision that matches original OpenAI behaviour
203
+ if is_timm_model:
204
+ # FIXME this is a bit janky, create timm based model in low-precision and
205
+ # then cast only LayerNormFp32 instances back to float32 so they don't break.
206
+ # Why? The convert_weights_to_lp fn only works with native models.
207
+ model.to(device=device, dtype=dtype)
208
+ from transformer import LayerNormFp32
209
+ def _convert_ln(m):
210
+ if isinstance(m, LayerNormFp32):
211
+ m.weight.data = m.weight.data.to(torch.float32)
212
+ m.bias.data = m.bias.data.to(torch.float32)
213
+ model.apply(_convert_ln)
214
+ else:
215
+ model.to(device=device)
216
+ convert_weights_to_lp(model, dtype=dtype)
217
+ elif precision in ("pure_fp16", "pure_bf16"):
218
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
219
+ model.to(device=device, dtype=dtype)
220
+ else:
221
+ model.to(device=device)
222
+
223
+ pretrained_loaded = False
224
+ if pretrained:
225
+ checkpoint_path = ''
226
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
227
+ if pretrained_cfg:
228
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
229
+ elif os.path.exists(pretrained):
230
+ checkpoint_path = pretrained
231
+
232
+ if checkpoint_path:
233
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
234
+ load_checkpoint(model, checkpoint_path)
235
+ else:
236
+ error_str = (
237
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
238
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
239
+ logging.warning(error_str)
240
+ raise RuntimeError(error_str)
241
+ pretrained_loaded = True
242
+ elif has_hf_hub_prefix:
243
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
244
+ load_checkpoint(model, checkpoint_path)
245
+ pretrained_loaded = True
246
+
247
+ if require_pretrained and not pretrained_loaded:
248
+ # callers of create_model_from_pretrained always expect pretrained weights
249
+ raise RuntimeError(
250
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
251
+
252
+ # set image / mean metadata from pretrained_cfg if available, or use default
253
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
254
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
255
+
256
+ if output_dict and hasattr(model, "output_dict"):
257
+ model.output_dict = True
258
+
259
+ if jit:
260
+ model = torch.jit.script(model)
261
+
262
+ return model
263
+
264
+
265
+ def create_loss(args):
266
+ if args.distill:
267
+ return DistillClipLoss(
268
+ local_loss=args.local_loss,
269
+ gather_with_grad=args.gather_with_grad,
270
+ cache_labels=True,
271
+ rank=args.rank,
272
+ world_size=args.world_size,
273
+ use_horovod=args.horovod,
274
+ )
275
+ elif "coca" in args.model.lower():
276
+ return CoCaLoss(
277
+ caption_loss_weight=args.coca_caption_loss_weight,
278
+ clip_loss_weight=args.coca_contrastive_loss_weight,
279
+ local_loss=args.local_loss,
280
+ gather_with_grad=args.gather_with_grad,
281
+ cache_labels=True,
282
+ rank=args.rank,
283
+ world_size=args.world_size,
284
+ use_horovod=args.horovod,
285
+ )
286
+ return ClipLoss(
287
+ local_loss=args.local_loss,
288
+ gather_with_grad=args.gather_with_grad,
289
+ cache_labels=True,
290
+ rank=args.rank,
291
+ world_size=args.world_size,
292
+ use_horovod=args.horovod,
293
+ )
294
+
295
+
296
+ def create_model_and_transforms(
297
+ model_name: str,
298
+ pretrained: Optional[str] = None,
299
+ precision: str = 'fp32',
300
+ device: Union[str, torch.device] = 'cpu',
301
+ jit: bool = False,
302
+ force_quick_gelu: bool = False,
303
+ force_custom_text: bool = False,
304
+ force_patch_dropout: Optional[float] = None,
305
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
306
+ pretrained_image: bool = False,
307
+ pretrained_hf: bool = True,
308
+ image_mean: Optional[Tuple[float, ...]] = None,
309
+ image_std: Optional[Tuple[float, ...]] = None,
310
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
311
+ cache_dir: Optional[str] = None,
312
+ output_dict: Optional[bool] = None,
313
+ ):
314
+
315
+
316
+ model = create_model(
317
+ model_name,
318
+ pretrained,
319
+ precision=precision,
320
+ device=device,
321
+ jit=jit,
322
+ force_quick_gelu=force_quick_gelu,
323
+ force_custom_text=force_custom_text,
324
+ force_patch_dropout=force_patch_dropout,
325
+ force_image_size=force_image_size,
326
+ pretrained_image=pretrained_image,
327
+ pretrained_hf=pretrained_hf,
328
+ cache_dir=cache_dir,
329
+ output_dict=output_dict,
330
+ )
331
+
332
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
333
+ image_std = image_std or getattr(model.visual, 'image_std', None)
334
+ preprocess_train = image_transform(
335
+ model.visual.image_size,
336
+ is_train=True,
337
+ mean=image_mean,
338
+ std=image_std,
339
+ aug_cfg=aug_cfg,
340
+ )
341
+ preprocess_val = image_transform(
342
+ model.visual.image_size,
343
+ is_train=False,
344
+ mean=image_mean,
345
+ std=image_std,
346
+ )
347
+
348
+ return model, preprocess_train, preprocess_val
349
+
350
+
351
+ def create_model_from_pretrained(
352
+ model_name: str,
353
+ pretrained: Optional[str] = None,
354
+ precision: str = 'fp32',
355
+ device: Union[str, torch.device] = 'cpu',
356
+ jit: bool = False,
357
+ force_quick_gelu: bool = False,
358
+ force_custom_text: bool = False,
359
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
360
+ return_transform: bool = True,
361
+ image_mean: Optional[Tuple[float, ...]] = None,
362
+ image_std: Optional[Tuple[float, ...]] = None,
363
+ cache_dir: Optional[str] = None,
364
+ ):
365
+ model = create_model(
366
+ model_name,
367
+ pretrained,
368
+ precision=precision,
369
+ device=device,
370
+ jit=jit,
371
+ force_quick_gelu=force_quick_gelu,
372
+ force_custom_text=force_custom_text,
373
+ force_image_size=force_image_size,
374
+ cache_dir=cache_dir,
375
+ require_pretrained=True,
376
+ )
377
+
378
+ if not return_transform:
379
+ return model
380
+
381
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
382
+ image_std = image_std or getattr(model.visual, 'image_std', None)
383
+ preprocess = image_transform(
384
+ model.visual.image_size,
385
+ is_train=False,
386
+ mean=image_mean,
387
+ std=image_std,
388
+ )
389
+
390
+ return model, preprocess
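For completeness, a hedged sketch of driving this factory module directly (the model name `ViT-L-14-336`, the checkpoint path, and the availability of `encode_image`/`encode_text` on the `CLIP` class are assumptions; the name must match a config under `model_configs/`, and the checkpoint must be in the format `load_checkpoint` expects, including the `num_register_tokens` and `neuron_dict` entries it reads):

```python
import torch
from PIL import Image

from factory import create_model_and_transforms, get_tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative names/paths only.
model, _, preprocess_val = create_model_and_transforms(
    "ViT-L-14-336",
    pretrained="path/to/checkpoint.pt",
    device=device,
)
tokenizer = get_tokenizer("ViT-L-14-336")

image = preprocess_val(Image.open("your_image.jpg")).unsqueeze(0).to(device)
text_tokens = tokenizer(["a photo of a cat", "a photo of a dog"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text_tokens)
```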
imagenet_classes.py ADDED
@@ -0,0 +1,85 @@
1
+ OPENAI_IMAGENET_TEMPLATES = (
2
+ lambda c: f'a bad photo of a {c}.',
3
+ lambda c: f'a photo of many {c}.',
4
+ lambda c: f'a sculpture of a {c}.',
5
+ lambda c: f'a photo of the hard to see {c}.',
6
+ lambda c: f'a low resolution photo of the {c}.',
7
+ lambda c: f'a rendering of a {c}.',
8
+ lambda c: f'graffiti of a {c}.',
9
+ lambda c: f'a bad photo of the {c}.',
10
+ lambda c: f'a cropped photo of the {c}.',
11
+ lambda c: f'a tattoo of a {c}.',
12
+ lambda c: f'the embroidered {c}.',
13
+ lambda c: f'a photo of a hard to see {c}.',
14
+ lambda c: f'a bright photo of a {c}.',
15
+ lambda c: f'a photo of a clean {c}.',
16
+ lambda c: f'a photo of a dirty {c}.',
17
+ lambda c: f'a dark photo of the {c}.',
18
+ lambda c: f'a drawing of a {c}.',
19
+ lambda c: f'a photo of my {c}.',
20
+ lambda c: f'the plastic {c}.',
21
+ lambda c: f'a photo of the cool {c}.',
22
+ lambda c: f'a close-up photo of a {c}.',
23
+ lambda c: f'a black and white photo of the {c}.',
24
+ lambda c: f'a painting of the {c}.',
25
+ lambda c: f'a painting of a {c}.',
26
+ lambda c: f'a pixelated photo of the {c}.',
27
+ lambda c: f'a sculpture of the {c}.',
28
+ lambda c: f'a bright photo of the {c}.',
29
+ lambda c: f'a cropped photo of a {c}.',
30
+ lambda c: f'a plastic {c}.',
31
+ lambda c: f'a photo of the dirty {c}.',
32
+ lambda c: f'a jpeg corrupted photo of a {c}.',
33
+ lambda c: f'a blurry photo of the {c}.',
34
+ lambda c: f'a photo of the {c}.',
35
+ lambda c: f'a good photo of the {c}.',
36
+ lambda c: f'a rendering of the {c}.',
37
+ lambda c: f'a {c} in a video game.',
38
+ lambda c: f'a photo of one {c}.',
39
+ lambda c: f'a doodle of a {c}.',
40
+ lambda c: f'a close-up photo of the {c}.',
41
+ lambda c: f'a photo of a {c}.',
42
+ lambda c: f'the origami {c}.',
43
+ lambda c: f'the {c} in a video game.',
44
+ lambda c: f'a sketch of a {c}.',
45
+ lambda c: f'a doodle of the {c}.',
46
+ lambda c: f'a origami {c}.',
47
+ lambda c: f'a low resolution photo of a {c}.',
48
+ lambda c: f'the toy {c}.',
49
+ lambda c: f'a rendition of the {c}.',
50
+ lambda c: f'a photo of the clean {c}.',
51
+ lambda c: f'a photo of a large {c}.',
52
+ lambda c: f'a rendition of a {c}.',
53
+ lambda c: f'a photo of a nice {c}.',
54
+ lambda c: f'a photo of a weird {c}.',
55
+ lambda c: f'a blurry photo of a {c}.',
56
+ lambda c: f'a cartoon {c}.',
57
+ lambda c: f'art of a {c}.',
58
+ lambda c: f'a sketch of the {c}.',
59
+ lambda c: f'a embroidered {c}.',
60
+ lambda c: f'a pixelated photo of a {c}.',
61
+ lambda c: f'itap of the {c}.',
62
+ lambda c: f'a jpeg corrupted photo of the {c}.',
63
+ lambda c: f'a good photo of a {c}.',
64
+ lambda c: f'a plushie {c}.',
65
+ lambda c: f'a photo of the nice {c}.',
66
+ lambda c: f'a photo of the small {c}.',
67
+ lambda c: f'a photo of the weird {c}.',
68
+ lambda c: f'the cartoon {c}.',
69
+ lambda c: f'art of the {c}.',
70
+ lambda c: f'a drawing of the {c}.',
71
+ lambda c: f'a photo of the large {c}.',
72
+ lambda c: f'a black and white photo of a {c}.',
73
+ lambda c: f'the plushie {c}.',
74
+ lambda c: f'a dark photo of a {c}.',
75
+ lambda c: f'itap of a {c}.',
76
+ lambda c: f'graffiti of the {c}.',
77
+ lambda c: f'a toy {c}.',
78
+ lambda c: f'itap of my {c}.',
79
+ lambda c: f'a photo of a cool {c}.',
80
+ lambda c: f'a photo of a small {c}.',
81
+ lambda c: f'a tattoo of the {c}.',
82
+ )
83
+
84
+
85
+ IMAGENET_CLASSNAMES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem 
bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", 
"paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
misc.py ADDED
@@ -0,0 +1,114 @@
1
+ from itertools import repeat
2
+ import collections.abc
3
+
4
+ import torch
5
+ from torch import nn as nn
6
+ from torchvision.ops.misc import FrozenBatchNorm2d
7
+
8
+
9
+ def freeze_batch_norm_2d(module, module_match={}, name=''):
10
+ """
11
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
12
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
13
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
14
+
15
+ Args:
16
+ module (torch.nn.Module): Any PyTorch module.
17
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
18
+ name (str): Full module name (prefix)
19
+
20
+ Returns:
21
+ torch.nn.Module: Resulting module
22
+
23
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
24
+ """
25
+ res = module
26
+ is_match = True
27
+ if module_match:
28
+ is_match = name in module_match
29
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
30
+ res = FrozenBatchNorm2d(module.num_features)
31
+ res.num_features = module.num_features
32
+ res.affine = module.affine
33
+ if module.affine:
34
+ res.weight.data = module.weight.data.clone().detach()
35
+ res.bias.data = module.bias.data.clone().detach()
36
+ res.running_mean.data = module.running_mean.data
37
+ res.running_var.data = module.running_var.data
38
+ res.eps = module.eps
39
+ else:
40
+ for child_name, child in module.named_children():
41
+ full_child_name = '.'.join([name, child_name]) if name else child_name
42
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
43
+ if new_child is not child:
44
+ res.add_module(child_name, new_child)
45
+ return res
46
+
47
+
48
+ # From PyTorch internals
49
+ def _ntuple(n):
50
+ def parse(x):
51
+ if isinstance(x, collections.abc.Iterable):
52
+ return x
53
+ return tuple(repeat(x, n))
54
+ return parse
55
+
56
+
57
+ to_1tuple = _ntuple(1)
58
+ to_2tuple = _ntuple(2)
59
+ to_3tuple = _ntuple(3)
60
+ to_4tuple = _ntuple(4)
61
+ to_ntuple = lambda n, x: _ntuple(n)(x)
62
+
63
+ # Replaces all linear layers with linear_replacement
64
+ # TODO: add int8 support for other linear layers including attn and convnets
65
+ def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
66
+ for name, module in model.named_children():
67
+ if len(list(module.children())) > 0:
68
+ replace_linear(module, linear_replacement, include_modules, copy_weights)
69
+
70
+ if isinstance(module, torch.nn.Linear) and name in include_modules:
71
+ old_module = model._modules[name]
72
+ model._modules[name] = linear_replacement(
73
+ module.in_features,
74
+ module.out_features,
75
+ module.bias is not None,
76
+ )
77
+ if copy_weights:
78
+ model._modules[name].weight.data.copy_(old_module.weight.data)
79
+ if model._modules[name].bias is not None:
80
+ model._modules[name].bias.data.copy_(old_module.bias)
81
+
82
+ return model
83
+
84
+ def convert_int8_model_to_inference_mode(model):
85
+ for m in model.modules():
86
+ if hasattr(m, 'prepare_for_eval'):
87
+ int8_original_dtype = m.weight.dtype
88
+ m.prepare_for_eval()
89
+ m.int8_original_dtype = int8_original_dtype
90
+
91
+
92
+ def accuracy(output, target, topk=(1,)):
93
+ """
94
+ Compute top-k accuracy
95
+
96
+ output: torch.Tensor
97
+ shape (N, C) where N is the number of examples, C the number of classes.
98
+ these are the logits.
99
+
100
+ target: torch.Tensor
101
+ shape (N,) where N is the number of examples. Groundtruth class id of each example.
102
+
103
+ topk: tuple
104
+ which topk to compute, e.g., topk=(1,5) will compute top-1 and top-5 accuracies
105
+
106
+ Returns
107
+ -------
108
+
109
+ list of top-k accuracies in the same order as `topk`
110
+ """
111
+ pred = output.topk(max(topk), 1, True, True)[1].t()
112
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
113
+ n = len(target)
114
+ return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk]
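A tiny usage sketch for `accuracy` (illustrative values only):

```python
import torch

from misc import accuracy

# Three examples, four classes; logits are made up for illustration.
logits = torch.tensor([
    [0.10, 0.90, 0.00, 0.00],  # top-1 prediction: class 1 (correct)
    [0.80, 0.05, 0.10, 0.05],  # top-1: class 0 (wrong), top-2 includes class 2 (correct)
    [0.20, 0.10, 0.30, 0.40],  # top-1: class 3 (correct)
])
targets = torch.tensor([1, 2, 3])

top1, top2 = accuracy(logits, targets, topk=(1, 2))
print(top1, top2)  # ~0.667, 1.0
```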
model.py ADDED
@@ -0,0 +1,431 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union, Text
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+
17
+ from modified_resnet import ModifiedResNet
18
+ from timm_model import TimmModel
19
+ from transformer import LayerNorm, QuickGELU, VisionTransformer, TextTransformer, Attention
20
+ from misc import to_2tuple
21
+
22
+
23
+
24
+ @dataclass
25
+ class CLIPVisionCfg:
26
+ layers: Union[Tuple[int, int, int, int], int] = 12
27
+ width: int = 768
28
+ head_width: int = 64
29
+ mlp_ratio: float = 4.0
30
+ patch_size: int = 16
31
+ image_size: Union[Tuple[int, int], int] = 224
32
+
33
+ ls_init_value: Optional[float] = None # layer scale initial value
34
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
35
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
36
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
37
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
38
+ n_queries: int = 256 # n_queries for attentional pooler
39
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
40
+ output_tokens: bool = False
41
+
42
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
43
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
44
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
45
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
46
+ timm_proj_bias: bool = False # enable bias final projection
47
+ timm_drop: float = 0. # head dropout
48
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
49
+
50
+
51
+
52
+
53
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
54
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
55
+
56
+ def _convert_weights(l):
57
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
58
+ l.weight.data = l.weight.data.to(dtype)
59
+ if l.bias is not None:
60
+ l.bias.data = l.bias.data.to(dtype)
61
+
62
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
63
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
64
+ tensor = getattr(l, attr)
65
+ if tensor is not None:
66
+ tensor.data = tensor.data.to(dtype)
67
+
68
+ if isinstance(l, (CLIP, TextTransformer)):
69
+ # convert text nn.Parameter projections
70
+ attr = getattr(l, "text_projection", None)
71
+ if attr is not None:
72
+ attr.data = attr.data.to(dtype)
73
+
74
+ if isinstance(l, VisionTransformer):
75
+ # convert vision nn.Parameter projections
76
+ attr = getattr(l, "proj", None)
77
+ if attr is not None:
78
+ attr.data = attr.data.to(dtype)
79
+
80
+ model.apply(_convert_weights)
81
+
82
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
83
+
84
+
85
+ @dataclass
86
+ class CLIPTextCfg:
87
+ context_length: int = 77
88
+ vocab_size: int = 49408
89
+ width: int = 512
90
+ heads: int = 8
91
+ layers: int = 12
92
+ ls_init_value: Optional[float] = None # layer scale initial value
93
+ hf_model_name: str = None
94
+ hf_tokenizer_name: str = None
95
+ hf_model_pretrained: bool = True
96
+ proj: str = 'mlp'
97
+ pooler_type: str = 'mean_pooler'
98
+ embed_cls: bool = False
99
+ pad_id: int = 0
100
+ output_tokens: bool = False
101
+
102
+
103
+ def get_cast_dtype(precision: str):
104
+ cast_dtype = None
105
+ if precision == 'bf16':
106
+ cast_dtype = torch.bfloat16
107
+ elif precision == 'fp16':
108
+ cast_dtype = torch.float16
109
+ return cast_dtype
110
+
111
+
112
+ def get_input_dtype(precision: str):
113
+ input_dtype = None
114
+ if precision in ('bf16', 'pure_bf16'):
115
+ input_dtype = torch.bfloat16
116
+ elif precision in ('fp16', 'pure_fp16'):
117
+ input_dtype = torch.float16
118
+ return input_dtype
119
+
120
+
121
+ def _build_vision_tower(
122
+ embed_dim: int,
123
+ vision_cfg: CLIPVisionCfg,
124
+ quick_gelu: bool = False,
125
+ cast_dtype: Optional[torch.dtype] = None,
126
+ ):
127
+ if isinstance(vision_cfg, dict):
128
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
129
+
130
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
131
+ # memory efficient in recent PyTorch releases (>= 1.10).
132
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
133
+ act_layer = QuickGELU if quick_gelu else nn.GELU
134
+
135
+ if vision_cfg.timm_model_name:
136
+ visual = TimmModel(
137
+ vision_cfg.timm_model_name,
138
+ pretrained=vision_cfg.timm_model_pretrained,
139
+ pool=vision_cfg.timm_pool,
140
+ proj=vision_cfg.timm_proj,
141
+ proj_bias=vision_cfg.timm_proj_bias,
142
+ drop=vision_cfg.timm_drop,
143
+ drop_path=vision_cfg.timm_drop_path,
144
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
145
+ embed_dim=embed_dim,
146
+ image_size=vision_cfg.image_size,
147
+ )
148
+ elif isinstance(vision_cfg.layers, (tuple, list)):
149
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
150
+ visual = ModifiedResNet(
151
+ layers=vision_cfg.layers,
152
+ output_dim=embed_dim,
153
+ heads=vision_heads,
154
+ image_size=vision_cfg.image_size,
155
+ width=vision_cfg.width,
156
+ )
157
+ else:
158
+ vision_heads = vision_cfg.width // vision_cfg.head_width
159
+ norm_layer = LayerNorm
160
+ visual = VisionTransformer(
161
+ image_size=vision_cfg.image_size,
162
+ patch_size=vision_cfg.patch_size,
163
+ width=vision_cfg.width,
164
+ layers=vision_cfg.layers,
165
+ heads=vision_heads,
166
+ mlp_ratio=vision_cfg.mlp_ratio,
167
+ ls_init_value=vision_cfg.ls_init_value,
168
+ patch_dropout=vision_cfg.patch_dropout,
169
+ input_patchnorm=vision_cfg.input_patchnorm,
170
+ global_average_pool=vision_cfg.global_average_pool,
171
+ attentional_pool=vision_cfg.attentional_pool,
172
+ n_queries=vision_cfg.n_queries,
173
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
174
+ output_tokens=vision_cfg.output_tokens,
175
+ output_dim=embed_dim,
176
+ act_layer=act_layer,
177
+ norm_layer=norm_layer,
178
+ )
179
+
180
+ return visual
181
+
182
+
183
+ def _build_text_tower(
184
+ embed_dim: int,
185
+ text_cfg: CLIPTextCfg,
186
+ quick_gelu: bool = False,
187
+ cast_dtype: Optional[torch.dtype] = None,
188
+ ):
189
+ if isinstance(text_cfg, dict):
190
+ text_cfg = CLIPTextCfg(**text_cfg)
191
+
192
+ if text_cfg.hf_model_name:
193
+ from hf_model import HFTextEncoder
194
+ text = HFTextEncoder(
195
+ text_cfg.hf_model_name,
196
+ output_dim=embed_dim,
197
+ proj=text_cfg.proj,
198
+ pooler_type=text_cfg.pooler_type,
199
+ pretrained=text_cfg.hf_model_pretrained,
200
+ output_tokens=text_cfg.output_tokens,
201
+ )
202
+ else:
203
+ act_layer = QuickGELU if quick_gelu else nn.GELU
204
+ norm_layer = LayerNorm
205
+
206
+ text = TextTransformer(
207
+ context_length=text_cfg.context_length,
208
+ vocab_size=text_cfg.vocab_size,
209
+ width=text_cfg.width,
210
+ heads=text_cfg.heads,
211
+ layers=text_cfg.layers,
212
+ ls_init_value=text_cfg.ls_init_value,
213
+ output_dim=embed_dim,
214
+ embed_cls=text_cfg.embed_cls,
215
+ output_tokens=text_cfg.output_tokens,
216
+ pad_id=text_cfg.pad_id,
217
+ act_layer=act_layer,
218
+ norm_layer=norm_layer,
219
+ )
220
+ return text
221
+
222
+
223
+ class CLIP(nn.Module):
224
+ """
225
+ CLIP model combining a vision tower with an optional text tower.
226
+
227
+ In this repository the text tower is not instantiated; the model is
228
+ used as an image encoder only, with optional test-time register tokens
229
+ supplied via ``num_register_tokens`` and ``neuron_dict``.
230
+
231
+
232
+
233
+ """
234
+ output_dict: torch.jit.Final[bool]
235
+
236
+ def __init__(
237
+ self,
238
+ embed_dim: int,
239
+ vision_cfg: CLIPVisionCfg,
240
+ text_cfg: CLIPTextCfg,
241
+ quick_gelu: bool = False,
242
+ cast_dtype: Optional[torch.dtype] = None,
243
+ output_dict: bool = False,
244
+ ):
245
+ super().__init__()
246
+ self.output_dict = output_dict
247
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
248
+ print(f"Building vision tower with config: {vision_cfg}")
249
+
250
+ print(f"Currently text tower is removed, using only image encoder for feature extraction")
251
+ do_use = False
252
+ if do_use:
253
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
254
+ self.transformer = text.transformer
255
+ self.context_length = text.context_length
256
+ self.vocab_size = text.vocab_size
257
+ self.token_embedding = text.token_embedding
258
+ self.positional_embedding = text.positional_embedding
259
+ self.ln_final = text.ln_final
260
+ self.text_projection = text.text_projection
261
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
262
+
263
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
264
+
265
+ self.num_register_tokens = 0
266
+ self.neuron_dict = None
267
+
268
+ @torch.jit.ignore
269
+ def set_grad_checkpointing(self, enable=True):
270
+ self.visual.set_grad_checkpointing(enable)
271
+ self.transformer.grad_checkpointing = enable
272
+
273
+ def encode_image(self, image, normalize: bool = False, attn_method: Text = 'direct', num_register_tokens = None, neuron_dict=None):
274
+ if num_register_tokens is None and neuron_dict is None:
275
+ num_register_tokens = self.num_register_tokens
276
+ neuron_dict = self.neuron_dict
277
+
278
+
279
+ features = self.visual(image, attn_method=attn_method, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict)
280
+ return F.normalize(features, dim=-1) if normalize else features
281
+
282
+ def encode_text(self, text, normalize: bool = False):
283
+ cast_dtype = self.transformer.get_cast_dtype()
284
+
285
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
286
+
287
+ x = x + self.positional_embedding.to(cast_dtype)
288
+ # x = x.permute(1, 0, 2) # NLD -> LND
289
+ x = self.transformer(x, attn_mask=self.attn_mask)
290
+ # x = x.permute(1, 0, 2) # LND -> NLD
291
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
292
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
293
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
294
+ return F.normalize(x, dim=-1) if normalize else x
295
+
296
+ def forward(
297
+ self,
298
+ image: Optional[torch.Tensor] = None,
299
+ text: Optional[torch.Tensor] = None,
300
+ num_register_tokens = None,
301
+ neuron_dict=None
302
+
303
+ ):
304
+
305
+ if num_register_tokens is None and neuron_dict is None:
306
+ num_register_tokens = self.num_register_tokens
307
+ neuron_dict = self.neuron_dict
308
+
309
+
310
+ image_features = self.encode_image(image, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict, normalize=True) if image is not None else None
311
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
312
+ if self.output_dict:
313
+ return {
314
+ "image_features": image_features,
315
+ "text_features": text_features,
316
+ "logit_scale": self.logit_scale.exp()
317
+ }
318
+ return image_features, text_features, self.logit_scale.exp()
319
+
320
+
321
+ # used to maintain checkpoint compatibility
322
+ def convert_to_custom_text_state_dict(state_dict: dict):
323
+ if 'text_projection' in state_dict:
324
+ # old format state_dict, move text tower -> .text
325
+ new_state_dict = {}
326
+ for k, v in state_dict.items():
327
+ if any(k.startswith(p) for p in (
328
+ 'text_projection',
329
+ 'positional_embedding',
330
+ 'token_embedding',
331
+ 'transformer',
332
+ 'ln_final',
333
+ )):
334
+ k = 'text.' + k
335
+ new_state_dict[k] = v
336
+ return new_state_dict
337
+ return state_dict
338
+
339
+
340
+ def build_model_from_openai_state_dict(
341
+ state_dict: dict,
342
+ quick_gelu=True,
343
+ cast_dtype=torch.float16,
344
+ ):
345
+ vit = "visual.proj" in state_dict
346
+
347
+ if vit:
348
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
349
+ vision_layers = len(
350
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
351
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
352
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
353
+ image_size = vision_patch_size * grid_size
354
+ else:
355
+ counts: list = [
356
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
357
+ vision_layers = tuple(counts)
358
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
359
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
360
+ vision_patch_size = None
361
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
362
+ image_size = output_width * 32
363
+
364
+ embed_dim = state_dict["text_projection"].shape[1]
365
+ context_length = state_dict["positional_embedding"].shape[0]
366
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
367
+ transformer_width = state_dict["ln_final.weight"].shape[0]
368
+ transformer_heads = transformer_width // 64
369
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
370
+
371
+ vision_cfg = CLIPVisionCfg(
372
+ layers=vision_layers,
373
+ width=vision_width,
374
+ patch_size=vision_patch_size,
375
+ image_size=image_size,
376
+ )
377
+ text_cfg = CLIPTextCfg(
378
+ context_length=context_length,
379
+ vocab_size=vocab_size,
380
+ width=transformer_width,
381
+ heads=transformer_heads,
382
+ layers=transformer_layers,
383
+ )
384
+ model = CLIP(
385
+ embed_dim,
386
+ vision_cfg=vision_cfg,
387
+ text_cfg=text_cfg,
388
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
389
+ cast_dtype=cast_dtype,
390
+ )
391
+
392
+ for key in ["input_resolution", "context_length", "vocab_size"]:
393
+ state_dict.pop(key, None)
394
+
395
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
396
+ model.load_state_dict(state_dict)
397
+ return model.eval()
398
+
399
+
400
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
401
+ # Rescale the grid of position embeddings when loading from state_dict
402
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
403
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
404
+ return
405
+ grid_size = to_2tuple(model.visual.grid_size)
406
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
407
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
408
+ if new_seq_len == old_pos_embed.shape[0]:
409
+ return
410
+
411
+ if extra_tokens:
412
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
413
+ else:
414
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
415
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
416
+
417
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
418
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
419
+ pos_emb_img = F.interpolate(
420
+ pos_emb_img,
421
+ size=grid_size,
422
+ mode=interpolation,
423
+ antialias=antialias,
424
+ align_corners=False,
425
+ )
426
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
427
+ if pos_emb_tok is not None:
428
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
429
+ else:
430
+ new_pos_embed = pos_emb_img
431
+ state_dict['visual.positional_embedding'] = new_pos_embed
model_sanity_check.ipynb ADDED
@@ -0,0 +1,285 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "ba945813",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%load_ext autoreload\n",
11
+ "%autoreload 2"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "id": "e7cec94e",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "\n",
22
+ "import os, json, math, torch, tqdm\n",
23
+ "from pathlib import Path\n",
24
+ "from torchvision import transforms\n",
25
+ "from torchvision.datasets import ImageFolder\n",
26
+ "from torch.utils.data import DataLoader\n",
27
+ "from transformers import CLIPProcessor, CLIPModel\n",
28
+ "import os\n",
29
+ "import itertools\n",
30
+ "\n",
31
+ "import torch\n",
32
+ "import numpy as np\n",
33
+ "\n",
34
+ "import transformers\n",
35
+ "from transformers import AutoModel, AutoProcessor, CLIPForImageClassification, AutoConfig, AutoTokenizer\n",
36
+ "from torchvision import transforms\n",
37
+ "from torchvision.datasets import ImageNet\n",
38
+ "from torch.utils.data import Subset\n",
39
+ "from tqdm import tqdm\n",
40
+ "from PIL import Image\n",
41
+ "import matplotlib.ticker as mticker\n",
42
+ "import matplotlib.pyplot as plt\n",
43
+ "from mpl_toolkits.mplot3d import Axes3D # noqa: F401 – 3D 기능 활성화\n",
44
+ "import inspect\n",
45
+ "import torch.nn.functional as F\n",
46
+ "import torchvision.transforms.functional as VF\n",
47
+ "import tqdm\n",
48
+ "\n",
49
+ "\n",
50
+ "from functools import partial\n",
51
+ "\n",
52
+ "from torchvision import transforms\n",
53
+ "from torchvision.transforms import InterpolationMode\n",
54
+ "\n",
55
+ "from tqdm import tqdm\n",
56
+ "\n",
57
+ "import yaml\n",
58
+ "from pathlib import Path\n",
59
+ "\n",
60
+ "import sys\n",
61
+ "import os\n",
62
+ "\n",
63
+ "from imagenet_classes import *\n"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "b4c7a750",
70
+ "metadata": {},
71
+ "outputs": [
72
+ {
73
+ "name": "stdout",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "Pretrained path from config: /workspace/code/clipL336_TTR\n",
77
+ "✓ Added '/workspace/code/clipL336_TTR' to Python path.\n",
78
+ "✓ Successfully imported 'model' from '/workspace/code/clipL336_TTR'\n",
79
+ "Building vision tower with config: CLIPVisionCfg(layers=24, width=1024, head_width=64, mlp_ratio=4.0, patch_size=14, image_size=336, ls_init_value=None, patch_dropout=0.0, input_patchnorm=False, global_average_pool=False, attentional_pool=False, n_queries=256, attn_pooler_heads=8, output_tokens=False, timm_model_name=None, timm_model_pretrained=False, timm_pool='avg', timm_proj='linear', timm_proj_bias=False, timm_drop=0.0, timm_drop_path=None)\n",
80
+ "✓ Added '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR' to Python path.\n",
81
+ "✓ Successfully imported 'tokenizer' from '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR'\n",
82
+ "Custom CLIP model loaded successfully!\n"
83
+ ]
84
+ }
85
+ ],
86
+ "source": [
87
+ "# 문제의 원인이 text encoder가 고장이 나 있었다..\n",
88
+ "device = \"cuda:7\"\n",
89
+ "model_path = \"/workspace/code/clipL336_TTR\"\n",
90
+ "\n",
91
+ "exp_cfg = AutoConfig.from_pretrained(\"/workspace/code/clipL336_TTR\", trust_remote_code=True)\n",
92
+ "exp_cfg.pretrained_path = model_path \n",
93
+ "# model = AutoModel.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)\n",
94
+ "model = AutoModel.from_pretrained(pretrained_model_name_or_path=model_path, config=exp_cfg, trust_remote_code=True, local_files_only=True)\n",
95
+ "# 여기 load 되었는 지 확인할 필요 있음\n",
96
+ "model = model.to(device)\n",
97
+ "preprocessor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)\n",
98
+ "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)\n",
99
+ "# tokenizer랑 preprocessor 가져오기\n",
100
+ "\n",
101
+ "clip_transform = lambda image: preprocessor.image_processor(image, return_tensors=\"pt\")['pixel_values'].squeeze(0) # 와 이렇게 활용할 방법은 생각도 못했네\n",
102
+ "model_clip = AutoModel.from_pretrained(\"openai/clip-vit-large-patch14-336\").to(device).half()"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 4,
108
+ "id": "ed3cbfdc",
109
+ "metadata": {},
110
+ "outputs": [
111
+ {
112
+ "name": "stderr",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "100%|██████████| 1000/1000 [00:23<00:00, 41.71it/s]"
116
+ ]
117
+ },
118
+ {
119
+ "name": "stdout",
120
+ "output_type": "stream",
121
+ "text": [
122
+ "Built text features: torch.Size([768, 1000])\n"
123
+ ]
124
+ },
125
+ {
126
+ "name": "stderr",
127
+ "output_type": "stream",
128
+ "text": [
129
+ "\n"
130
+ ]
131
+ }
132
+ ],
133
+ "source": [
134
+ "# langauge head\n",
135
+ "### zeroshot head construction (text encoding) ###\n",
136
+ "with torch.no_grad():\n",
137
+ " zeroshot_weight = []\n",
138
+ " for classname in tqdm(IMAGENET_CLASSNAMES):\n",
139
+ " texts = [template(classname) for template in OPENAI_IMAGENET_TEMPLATES]\n",
140
+ " text_inputs = preprocessor(text=texts, return_tensors=\"pt\", padding=\"max_length\").to(device)\n",
141
+ " # text_inputs = model.tokenize(texts).to(device)\n",
142
+ " # text_features = model.encode_text(text_inputs.input_ids)\n",
143
+ " text_features = model_clip.get_text_features(**text_inputs)\n",
144
+ " text_feature = F.normalize(text_features, dim=-1).mean(dim=0)\n",
145
+ " # text_feature = text_features.mean(dim=0)\n",
146
+ " text_feature = text_feature / text_feature.norm()\n",
147
+ " zeroshot_weight.append(text_feature)\n",
148
+ " \n",
149
+ " text_features = torch.stack(zeroshot_weight, dim=1).to(device)\n",
150
+ "print(\"Built text features:\", text_features.shape)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 10,
156
+ "id": "e1bd37d1",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "torch.save(text_features, \"./zeroshot_classifier.pt\")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 5,
166
+ "id": "dbfeaedf",
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "imagenet_dataset = ImageNet(root='/workspace/data/imagenet', split='val', transform=clip_transform)\n",
171
+ "eval_loader = torch.utils.data.DataLoader(imagenet_dataset, batch_size=128, num_workers=16, pin_memory=False, shuffle=False)"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 6,
177
+ "id": "b0000195",
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "import numpy as np\n",
182
+ "\n",
183
+ "half = \"torch.bfloat16\"\n",
184
+ "def evaluate(model, loader, text_feats, max_samples: int | None = None):\n",
185
+ " model.eval()\n",
186
+ " top1 = top5 = n = 0\n",
187
+ " pbar = tqdm(loader, desc=\"Evaluating\", unit=\"batch\")\n",
188
+ " with torch.no_grad():\n",
189
+ " for images, labels in pbar:\n",
190
+ " if max_samples and n >= max_samples:\n",
191
+ " break\n",
192
+ " images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)\n",
193
+ " with torch.autocast(device_type=\"cuda\"):\n",
194
+ " # 여기 test-time 가공 함수 구현 필요\n",
195
+ " feats = model.encode_image(images)\n",
196
+ "\n",
197
+ " feats = feats / feats.norm(dim=-1, keepdim=True)\n",
198
+ " logits = model.model.logit_scale.exp() * feats @ text_feats \n",
199
+ " _, pred = logits.topk(5, dim=-1)\n",
200
+ " top1 += (pred[:, :1] == labels.unsqueeze(1)).sum().item()\n",
201
+ " top5 += (pred == labels.unsqueeze(1)).sum().item()\n",
202
+ " n += images.size(0)\n",
203
+ " pbar.set_postfix(samples=n, top1=top1 / n * 100, top5=top5 / n * 100)\n",
204
+ " return top1 / n * 100, top5 / n * 100\n"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 8,
210
+ "id": "8795b394",
211
+ "metadata": {},
212
+ "outputs": [
213
+ {
214
+ "name": "stderr",
215
+ "output_type": "stream",
216
+ "text": [
217
+ "Evaluating: 0%| | 0/391 [00:00<?, ?batch/s]"
218
+ ]
219
+ },
220
+ {
221
+ "name": "stderr",
222
+ "output_type": "stream",
223
+ "text": [
224
+ "Evaluating: 100%|██████████| 391/391 [10:38<00:00, 1.63s/batch, samples=5e+4, top1=74.9, top5=94.4] "
225
+ ]
226
+ },
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "Baseline (Top‑1 / Top‑5) on 50,000 imgs: 74.87% / 94.37%\n"
232
+ ]
233
+ },
234
+ {
235
+ "name": "stderr",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "\n"
239
+ ]
240
+ }
241
+ ],
242
+ "source": [
243
+ "\n",
244
+ "### baseline evaluator ###\n",
245
+ "### 이거는 지금 당장은 못 써먹는다... 미친 너무 느리다 어디서 문제지 ###\n",
246
+ "# 씨발 이번에 뭐지\n",
247
+ "# architecture define이 어딘가에서 손상 된 것으로 보인다\n",
248
+ "# 성능 reproduce...\n",
249
+ "\n",
250
+ "BASELINE_SAMPLES = 50000 # set to None for full 50 k\n",
251
+ "acc1, acc5 = evaluate(model, eval_loader, text_features, max_samples=BASELINE_SAMPLES)\n",
252
+ "print(f\"Baseline (Top‑1 / Top‑5) on {BASELINE_SAMPLES or len(imagenet_dataset):,} imgs: {acc1:.2f}% / {acc5:.2f}%\")"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "id": "4aa82bb4",
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": []
262
+ }
263
+ ],
264
+ "metadata": {
265
+ "kernelspec": {
266
+ "display_name": "base",
267
+ "language": "python",
268
+ "name": "python3"
269
+ },
270
+ "language_info": {
271
+ "codemirror_mode": {
272
+ "name": "ipython",
273
+ "version": 3
274
+ },
275
+ "file_extension": ".py",
276
+ "mimetype": "text/x-python",
277
+ "name": "python",
278
+ "nbconvert_exporter": "python",
279
+ "pygments_lexer": "ipython3",
280
+ "version": "3.10.14"
281
+ }
282
+ },
283
+ "nbformat": 4,
284
+ "nbformat_minor": 5
285
+ }
modeling_custom_clip.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ Custom CLIP Model with Register Tokens - Import Safe Version with Complete File Download
3
+ """
4
+ import transformers
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import PreTrainedModel, PretrainedConfig
8
+ from transformers.utils import logging
9
+ from typing import Optional, Union, Tuple
10
+ import json
11
+ from pathlib import Path
12
+ import warnings
13
+ import os
14
+ import sys
15
+ import importlib.util
16
+
17
+ # Suppress all warnings during import
18
+ warnings.filterwarnings("ignore")
19
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ def safe_import_from_repo(module_name: str, repo_path: str):
25
+ """
26
+ Safely import a Python module from the given local path (repo_path).
27
+
28
+ Args:
29
+ module_name (str): Name of the module to import (e.g. 'modeling_clip').
30
+ repo_path (str): Path to the directory that contains the module.
31
+
32
+ Returns:
33
+ The imported module object.
34
+
35
+ Raises:
36
+ ValueError: If repo_path is None or is not a valid directory.
37
+ ImportError: If the module cannot be found at the given path.
38
+ """
39
+ # 1. Validate repo_path.
40
+ if repo_path is None:
41
+ raise ValueError("The 'repo_path' argument cannot be None.")
42
+
43
+ # Convert to a pathlib.Path object so the path is easier to work with.
44
+ repo_path_obj = Path(repo_path)
45
+
46
+ if not repo_path_obj.is_dir():
47
+ raise ValueError(
48
+ f"The specified repo_path does not exist or is not a directory: '{repo_path}'")
49
+
50
+ # 2. Add the path to sys.path so Python can find the module.
51
+ # Use resolve() to get an absolute path, converted to a string.
52
+ absolute_repo_path = str(repo_path_obj.resolve())
53
+
54
+ if absolute_repo_path not in sys.path:
55
+ # Insert at the front of sys.path so it is searched before other paths.
56
+ sys.path.insert(0, absolute_repo_path)
57
+ print(f"✓ Added '{absolute_repo_path}' to Python path.")
58
+
59
+ # 3. Dynamically import the module using `importlib`.
60
+ try:
61
+ module = importlib.import_module(module_name)
62
+ print(
63
+ f"✓ Successfully imported '{module_name}' from '{absolute_repo_path}'")
64
+ return module
65
+ except ImportError:
66
+ # If the import fails even after adding the path to sys.path,
67
+ # it means the module file (.py) does not exist at that path.
68
+ raise ImportError(
69
+ f"Module '{module_name}' not found inside the specified path: '{absolute_repo_path}'")
70
+
71
+
72
+ class CustomCLIPConfig(PretrainedConfig):
73
+ model_type = "custom_clip_with_registers"
74
+
75
+ def __init__(
76
+ self,
77
+ vision_config=None,
78
+ text_config=None,
79
+ num_register_tokens=0,
80
+ neuron_dict=None,
81
+ projection_dim=512,
82
+ logit_scale_init_value=2.6592,
83
+ **kwargs
84
+ ):
85
+ super().__init__(**kwargs)
86
+
87
+ self.vision_config = vision_config or {}
88
+ self.text_config = text_config or {}
89
+ self.num_register_tokens = num_register_tokens
90
+ self.neuron_dict = neuron_dict
91
+ self.projection_dim = projection_dim
92
+ self.logit_scale_init_value = logit_scale_init_value
93
+
94
+
95
+ class CustomCLIPModel(PreTrainedModel):
96
+ config_class = CustomCLIPConfig
97
+
98
+ def __init__(self, config):
99
+ super().__init__(config)
100
+
101
+ # Safe import of custom modules
102
+ try:
103
+ # to strictly load from the local library
104
+ pretrained_path: str | None = getattr(
105
+ config, "pretrained_path", None)
106
+ if pretrained_path is None:
107
+ raise ValueError(
108
+ "The config must have a 'pretrained_path' attribute pointing to the local repository path.")
109
+ else:
110
+ print(f"Pretrained path from config: {pretrained_path}")
111
+
112
+ model_module = safe_import_from_repo('model', pretrained_path)
113
+ self.CLIP = model_module.CLIP
114
+ self.CLIPVisionCfg = model_module.CLIPVisionCfg
115
+ self.CLIPTextCfg = model_module.CLIPTextCfg
116
+ except ImportError as e:
117
+ raise ImportError(
118
+ f"Could not import model components: {e}. Make sure all model files are in the repository.")
119
+
120
+ # Create vision and text configs
121
+ vision_cfg = self.CLIPVisionCfg(
122
+ layers=config.vision_config.get("num_hidden_layers", 12),
123
+ width=config.vision_config.get("hidden_size", 768),
124
+ patch_size=config.vision_config.get("patch_size", 16),
125
+ image_size=config.vision_config.get("image_size", 224),
126
+ )
127
+
128
+ text_cfg = self.CLIPTextCfg(
129
+ context_length=config.text_config.get(
130
+ "max_position_embeddings", 77),
131
+ vocab_size=config.text_config.get("vocab_size", 49408),
132
+ width=config.text_config.get("hidden_size", 512),
133
+ layers=config.text_config.get("num_hidden_layers", 12),
134
+ )
135
+
136
+ # Initialize your custom CLIP model
137
+ self.model = self.CLIP(
138
+ embed_dim=config.projection_dim,
139
+ vision_cfg=vision_cfg,
140
+ text_cfg=text_cfg,
141
+ )
142
+
143
+ # These will be set when loading the state dict
144
+ # Do not load these from the state dict; they are already provided in the configuration.
145
+ self.neuron_dict = config.neuron_dict
146
+ if self.neuron_dict is None:
147
+ raise ValueError("neuron_dict must be provided in the config.")
148
+ self.num_register_tokens = config.num_register_tokens
149
+
150
+ # These will be loaded separately
151
+ self._tokenizer = None
152
+ self._preprocessor = None
153
+ self._zeroshot_classifier = None
154
+
155
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
156
+ """Override to handle custom parameters and load weights properly"""
157
+
158
+ # Extract custom parameters first
159
+ if 'neuron_dict' in state_dict:
160
+ self.neuron_dict = state_dict.pop('neuron_dict')
161
+
162
+ if 'num_register_tokens' in state_dict:
163
+ self.num_register_tokens = state_dict.pop('num_register_tokens')
164
+
165
+ # Set these values in the model
166
+ if hasattr(self.model, 'visual'):
167
+ self.model.visual.num_register_tokens = self.num_register_tokens
168
+ self.model.visual.neuron_dict = self.neuron_dict
169
+ self.model.num_register_tokens = self.num_register_tokens
170
+ self.model.neuron_dict = self.neuron_dict
171
+
172
+ # Load the weights properly - suppress ALL warnings and errors
173
+ with warnings.catch_warnings():
174
+ warnings.simplefilter("ignore")
175
+
176
+ # Temporarily set logging to critical only
177
+ original_level = logging.get_verbosity()
178
+ logging.set_verbosity_error()
179
+
180
+ try:
181
+ # Load weights directly into self.model
182
+ missing, unexpected = self.model.load_state_dict(
183
+ state_dict, strict=False)
184
+
185
+ # Don't report any missing/unexpected keys to avoid warnings
186
+
187
+ except Exception as e:
188
+ # If direct loading fails, try the parent method silently
189
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, False, [], [], [])
190
+ finally:
191
+ # Restore logging level
192
+ logging.set_verbosity(original_level)
193
+
194
+ @classmethod
195
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
196
+ """Override to load cleanly and suppress warnings"""
197
+
198
+ # Suppress warnings during loading
199
+ with warnings.catch_warnings():
200
+ warnings.simplefilter("ignore")
201
+
202
+ # Temporarily suppress transformers logging
203
+ original_level = logging.get_verbosity()
204
+ logging.set_verbosity_error()
205
+
206
+ try:
207
+ # Load the model
208
+ model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
209
+ finally:
210
+ # Restore logging
211
+ logging.set_verbosity(original_level)
212
+
213
+ # Load additional components
214
+ model._load_additional_components(pretrained_model_name_or_path)
215
+
216
+ # Print clean success message
217
+ print("Custom CLIP model loaded successfully!")
218
+
219
+ return model
220
+
221
+ def _load_additional_components(self, pretrained_model_name_or_path):
222
+ """Load tokenizer, preprocessor, and zero-shot classifier silently"""
223
+
224
+ try:
225
+ from huggingface_hub import hf_hub_download
226
+
227
+ # Load tokenizer
228
+ try:
229
+ # Safe import of tokenizer
230
+ tokenizer_module = safe_import_from_repo(
231
+ 'tokenizer', Path(__file__).parent)
232
+ self._tokenizer = tokenizer_module.SimpleTokenizer()
233
+ except ImportError:
234
+ # If tokenizer import fails, create a dummy tokenizer message
235
+ pass
236
+
237
+ # Load preprocessor
238
+ try:
239
+ preprocess_config_file = hf_hub_download(
240
+ repo_id=pretrained_model_name_or_path,
241
+ filename="preprocessor_config.json"
242
+ )
243
+
244
+ with open(preprocess_config_file, 'r') as f:
245
+ preprocess_config = json.load(f)
246
+
247
+ self._create_preprocessor(preprocess_config)
248
+ except:
249
+ pass
250
+
251
+ # Load zero-shot classifier
252
+ try:
253
+ classifier_file = hf_hub_download(
254
+ repo_id=pretrained_model_name_or_path,
255
+ filename="zeroshot_classifier.pt"
256
+ )
257
+
258
+ # Suppress the torch.load warning
259
+ with warnings.catch_warnings():
260
+ warnings.simplefilter("ignore")
261
+ self._zeroshot_classifier = torch.load(
262
+ classifier_file, map_location='cpu', weights_only=False)
263
+ except:
264
+ pass
265
+
266
+ except:
267
+ pass
268
+
269
+ def _create_preprocessor(self, config):
270
+ """Create image preprocessor from config"""
271
+ try:
272
+ from torchvision import transforms
273
+
274
+ self._preprocessor = transforms.Compose([
275
+ transforms.Resize(
276
+ config["image_size"], interpolation=transforms.InterpolationMode.BICUBIC),
277
+ transforms.CenterCrop(config["image_size"]),
278
+ transforms.ToTensor(),
279
+ transforms.Normalize(
280
+ mean=config["image_mean"], std=config["image_std"]),
281
+ ])
282
+ except:
283
+ pass
284
+
285
+ @property
286
+ def tokenizer(self):
287
+ """Access the tokenizer"""
288
+ return self._tokenizer
289
+
290
+ @property
291
+ def preprocessor(self):
292
+ """Access the image preprocessor"""
293
+ return self._preprocessor
294
+
295
+ @property
296
+ def zeroshot_classifier(self):
297
+ """Access the zero-shot classifier"""
298
+ return self._zeroshot_classifier
299
+
300
+ def tokenize(self, texts, context_length=77):
301
+ """Tokenize text using the loaded tokenizer"""
302
+ if self._tokenizer is None:
303
+ raise ValueError(
304
+ "Tokenizer not available. Make sure tokenizer.py is in the repository.")
305
+
306
+ # Safe import of tokenize function
307
+ try:
308
+ tokenizer_module = safe_import_from_repo(
309
+ 'tokenizer', Path(__file__).parent)
310
+ return tokenizer_module.tokenize(texts, context_length)
311
+ except ImportError:
312
+ raise ValueError("Could not import tokenize function.")
313
+
314
+ def preprocess_image(self, image):
315
+ """Preprocess image using the loaded preprocessor"""
316
+ if self._preprocessor is None:
317
+ raise ValueError(
318
+ "Preprocessor not loaded. Make sure preprocessor_config.json is in the repository.")
319
+
320
+ return self._preprocessor(image)
321
+
322
+ def forward(self, input_ids=None, pixel_values=None, num_register_tokens=None, neuron_dict=None, **kwargs):
323
+ """Forward pass supporting your custom functionality"""
324
+
325
+ if num_register_tokens is None:
326
+ num_register_tokens = self.num_register_tokens
327
+ if neuron_dict is None:
328
+ neuron_dict = self.neuron_dict
329
+
330
+ return self.model(
331
+ image=pixel_values,
332
+ text=input_ids,
333
+ num_register_tokens=num_register_tokens,
334
+ neuron_dict=neuron_dict
335
+ )
336
+
337
+ def encode_image(self, pixel_values, num_register_tokens=None, neuron_dict=None, **kwargs):
338
+ """Encode images with register token support"""
339
+ if num_register_tokens is None:
340
+ num_register_tokens = self.num_register_tokens
341
+ if neuron_dict is None:
342
+ neuron_dict = self.neuron_dict
343
+
344
+ return self.model.encode_image(
345
+ pixel_values,
346
+ num_register_tokens=num_register_tokens,
347
+ neuron_dict=neuron_dict,
348
+ **kwargs
349
+ )
350
+
351
+ def encode_text(self, input_ids, **kwargs):
352
+ """Encode text"""
353
+ return self.model.encode_text(input_ids, **kwargs)
354
+
355
+
356
+ # Auto-suppress warnings at module level
357
+ transformers.logging.set_verbosity_error()
modified_resnet.py ADDED
@@ -0,0 +1,181 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from misc import freeze_batch_norm_2d
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.act1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.act2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.act3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.act1(self.bn1(self.conv1(x)))
46
+ out = self.act2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.act3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x, key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0.,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+
92
+ return x[0]
93
+
94
+
95
+ class ModifiedResNet(nn.Module):
96
+ """
97
+ A ResNet class that is similar to torchvision's but contains the following changes:
98
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
+ - The final pooling layer is a QKV attention instead of an average pool
101
+ """
102
+
103
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104
+ super().__init__()
105
+ self.output_dim = output_dim
106
+ self.image_size = image_size
107
+
108
+ # the 3-layer stem
109
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(width // 2)
111
+ self.act1 = nn.ReLU(inplace=True)
112
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
+ self.bn2 = nn.BatchNorm2d(width // 2)
114
+ self.act2 = nn.ReLU(inplace=True)
115
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
+ self.bn3 = nn.BatchNorm2d(width)
117
+ self.act3 = nn.ReLU(inplace=True)
118
+ self.avgpool = nn.AvgPool2d(2)
119
+
120
+ # residual layers
121
+ self._inplanes = width # this is a *mutable* variable used during construction
122
+ self.layer1 = self._make_layer(width, layers[0])
123
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
+
127
+ embed_dim = width * 32 # the ResNet feature dimension
128
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129
+
130
+ self.init_parameters()
131
+
132
+ def _make_layer(self, planes, blocks, stride=1):
133
+ layers = [Bottleneck(self._inplanes, planes, stride)]
134
+
135
+ self._inplanes = planes * Bottleneck.expansion
136
+ for _ in range(1, blocks):
137
+ layers.append(Bottleneck(self._inplanes, planes))
138
+
139
+ return nn.Sequential(*layers)
140
+
141
+ def init_parameters(self):
142
+ if self.attnpool is not None:
143
+ std = self.attnpool.c_proj.in_features ** -0.5
144
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148
+
149
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150
+ for name, param in resnet_block.named_parameters():
151
+ if name.endswith("bn3.weight"):
152
+ nn.init.zeros_(param)
153
+
154
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156
+ for param in self.parameters():
157
+ param.requires_grad = False
158
+ if freeze_bn_stats:
159
+ freeze_batch_norm_2d(self)
160
+
161
+ @torch.jit.ignore
162
+ def set_grad_checkpointing(self, enable=True):
163
+ # FIXME support for non-transformer
164
+ pass
165
+
166
+ def stem(self, x):
167
+ x = self.act1(self.bn1(self.conv1(x)))
168
+ x = self.act2(self.bn2(self.conv2(x)))
169
+ x = self.act3(self.bn3(self.conv3(x)))
170
+ x = self.avgpool(x)
171
+ return x
172
+
173
+ def forward(self, x):
174
+ x = self.stem(x)
175
+ x = self.layer1(x)
176
+ x = self.layer2(x)
177
+ x = self.layer3(x)
178
+ x = self.layer4(x)
179
+ x = self.attnpool(x)
180
+
181
+ return x
neuron_indices.json ADDED
@@ -0,0 +1 @@
1
+ [[12, 42, 39.99140930175781], [12, 983, 34.50058364868164], [12, 3868, 23.993741989135742], [12, 2687, 23.192779541015625], [11, 3784, 14.847213745117188], [11, 987, 14.675474166870117], [11, 3661, 14.301347732543945], [12, 3008, 12.25265884399414], [11, 1967, 11.993508338928223], [12, 3002, 10.681584358215332], [11, 9, 9.61478042602539], [21, 1801, 8.448626518249512], [11, 2555, 6.903197288513184], [11, 1100, 6.859874725341797], [12, 1571, 4.70828104019165], [22, 901, 3.453416109085083], [21, 1550, 3.4134912490844727], [12, 1816, 3.37734055519104], [12, 183, 3.1418349742889404], [8, 745, 3.1221530437469482], [9, 4078, 3.0656824111938477], [9, 815, 3.0607407093048096], [10, 357, 2.7818374633789062], [9, 3618, 2.690423011779785], [10, 1654, 2.6796107292175293], [22, 2184, 2.6561291217803955], [10, 3940, 2.652881383895874], [7, 3228, 2.46209979057312], [10, 2319, 2.308473825454712], [9, 2693, 2.1979129314422607], [21, 1779, 2.1429498195648193], [20, 3077, 2.1137425899505615], [20, 2634, 2.04282808303833], [9, 3973, 2.031193733215332], [21, 3137, 2.026745080947876], [8, 3249, 1.9856672286987305], [8, 2585, 1.9620095491409302], [9, 1983, 1.9459211826324463], [9, 1744, 1.9378128051757812], [9, 1157, 1.749971866607666], [21, 2412, 1.7358660697937012], [10, 2560, 1.6931447982788086], [7, 2550, 1.6547895669937134], [21, 1381, 1.5941085815429688], [22, 1317, 1.560852289199829], [8, 1537, 1.5494486093521118], [8, 200, 1.4573794603347778], [19, 1881, 1.4518368244171143], [8, 1603, 1.416003704071045], [8, 1851, 1.3301061391830444], [8, 3523, 1.321004867553711], [12, 2780, 1.2789242267608643], [13, 1109, 1.2571412324905396], [10, 2559, 1.2549676895141602], [9, 1309, 1.238487958908081], [21, 2193, 1.2044764757156372], [17, 1868, 1.1777989864349365], [21, 1796, 1.1429805755615234], [10, 4009, 1.0898690223693848], [9, 1335, 1.0648274421691895], [22, 2889, 1.0604228973388672], [11, 888, 1.0271515846252441], [15, 415, 1.0207806825637817], [21, 68, 1.0149273872375488], [9, 3049, 0.9941853880882263], [9, 2607, 0.9631124138832092], [9, 2621, 0.954177737236023], [18, 1283, 0.9397207498550415], [9, 2396, 0.9153910875320435], [22, 797, 0.8976885676383972], [12, 2370, 0.8916781544685364], [22, 3026, 0.8911128044128418], [10, 3029, 0.864679217338562], [19, 2881, 0.8607441782951355], [9, 1610, 0.8600460886955261], [22, 3143, 0.8545671105384827], [19, 1149, 0.8446468114852905], [22, 806, 0.8359670042991638], [20, 676, 0.8346958756446838], [18, 3018, 0.8332728147506714], [22, 2714, 0.8295024037361145], [9, 2867, 0.813927948474884], [22, 3888, 0.8113453388214111], [8, 3697, 0.8057175278663635], [22, 1832, 0.7937105894088745], [22, 985, 0.7906701564788818], [22, 3361, 0.783061683177948], [9, 2394, 0.7818043231964111], [22, 3049, 0.7765958309173584], [8, 3137, 0.772114098072052], [10, 951, 0.7676676511764526], [11, 3568, 0.7665989398956299], [8, 2563, 0.7626394629478455], [23, 1137, 0.7513239979743958], [17, 604, 0.7489021420478821], [9, 1924, 0.7423470616340637], [19, 2106, 0.7369621992111206], [9, 2012, 0.7241123914718628], [10, 1903, 0.7238287329673767], [12, 3574, 0.7192695736885071]]
openai_models.py ADDED
@@ -0,0 +1,91 @@
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
14
+ from pretrained import *
15
+
16
+ __all__ = ["list_openai_models", "load_openai_model"]
17
+
18
+
19
+ def list_openai_models() -> List[str]:
20
+ """Returns the names of available CLIP models"""
21
+ return list_pretrained_models_by_tag('openai')
22
+
23
+
24
+ def load_openai_model(
25
+ name: str,
26
+ precision: Optional[str] = None,
27
+ device: Optional[Union[str, torch.device]] = None,
28
+ cache_dir: Optional[str] = None,
29
+ quick_gelu: Optional[bool] = True
30
+ ):
31
+ """Load a CLIP model
32
+
33
+ Parameters
34
+ ----------
35
+ name : str
36
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
37
+ precision: str
38
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
39
+ device : Union[str, torch.device]
40
+ The device to put the loaded model
41
+ cache_dir : Optional[str]
42
+ The directory to cache the downloaded model weights
43
+
44
+ Returns
45
+ -------
46
+ model : torch.nn.Module
47
+ The CLIP model
48
+ preprocess : Callable[[PIL.Image], torch.Tensor]
49
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
50
+ """
51
+ if device is None:
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+ if precision is None:
54
+ precision = 'fp32' if device == 'cpu' else 'fp16'
55
+
56
+ if get_pretrained_url(name, 'openai'):
57
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
58
+ elif os.path.isfile(name):
59
+ model_path = name
60
+ else:
61
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
62
+
63
+ try:
64
+ # loading JIT archive
65
+ model = torch.jit.load(model_path, map_location="cpu").eval()
66
+ state_dict = None
67
+ except RuntimeError:
68
+ # loading saved state dict
69
+ state_dict = torch.load(model_path, map_location="cpu")
70
+
71
+ # Build a non-jit model from the OpenAI jitted model state dict
72
+ cast_dtype = get_cast_dtype(precision)
73
+ try:
74
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), quick_gelu=quick_gelu, cast_dtype=cast_dtype)
75
+ except KeyError:
76
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
77
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype, quick_gelu=quick_gelu)
78
+
79
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
80
+ model = model.to(device)
81
+ # FIXME support pure fp16/bf16 precision modes
82
+ if precision != 'fp16':
83
+ model.float()
84
+ if precision == 'bf16':
85
+ # for bf16, convert back to low-precision
86
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
87
+
88
+ # add mean / std attributes for consistency with OpenCLIP models
89
+ model.visual.image_mean = OPENAI_DATASET_MEAN
90
+ model.visual.image_std = OPENAI_DATASET_STD
91
+ return model
openai_templates.py ADDED
@@ -0,0 +1,84 @@
1
+
2
+ OPENAI_IMAGENET_TEMPLATES = (
3
+ lambda c: f'a bad photo of a {c}.',
4
+ lambda c: f'a photo of many {c}.',
5
+ lambda c: f'a sculpture of a {c}.',
6
+ lambda c: f'a photo of the hard to see {c}.',
7
+ lambda c: f'a low resolution photo of the {c}.',
8
+ lambda c: f'a rendering of a {c}.',
9
+ lambda c: f'graffiti of a {c}.',
10
+ lambda c: f'a bad photo of the {c}.',
11
+ lambda c: f'a cropped photo of the {c}.',
12
+ lambda c: f'a tattoo of a {c}.',
13
+ lambda c: f'the embroidered {c}.',
14
+ lambda c: f'a photo of a hard to see {c}.',
15
+ lambda c: f'a bright photo of a {c}.',
16
+ lambda c: f'a photo of a clean {c}.',
17
+ lambda c: f'a photo of a dirty {c}.',
18
+ lambda c: f'a dark photo of the {c}.',
19
+ lambda c: f'a drawing of a {c}.',
20
+ lambda c: f'a photo of my {c}.',
21
+ lambda c: f'the plastic {c}.',
22
+ lambda c: f'a photo of the cool {c}.',
23
+ lambda c: f'a close-up photo of a {c}.',
24
+ lambda c: f'a black and white photo of the {c}.',
25
+ lambda c: f'a painting of the {c}.',
26
+ lambda c: f'a painting of a {c}.',
27
+ lambda c: f'a pixelated photo of the {c}.',
28
+ lambda c: f'a sculpture of the {c}.',
29
+ lambda c: f'a bright photo of the {c}.',
30
+ lambda c: f'a cropped photo of a {c}.',
31
+ lambda c: f'a plastic {c}.',
32
+ lambda c: f'a photo of the dirty {c}.',
33
+ lambda c: f'a jpeg corrupted photo of a {c}.',
34
+ lambda c: f'a blurry photo of the {c}.',
35
+ lambda c: f'a photo of the {c}.',
36
+ lambda c: f'a good photo of the {c}.',
37
+ lambda c: f'a rendering of the {c}.',
38
+ lambda c: f'a {c} in a video game.',
39
+ lambda c: f'a photo of one {c}.',
40
+ lambda c: f'a doodle of a {c}.',
41
+ lambda c: f'a close-up photo of the {c}.',
42
+ lambda c: f'a photo of a {c}.',
43
+ lambda c: f'the origami {c}.',
44
+ lambda c: f'the {c} in a video game.',
45
+ lambda c: f'a sketch of a {c}.',
46
+ lambda c: f'a doodle of the {c}.',
47
+ lambda c: f'a origami {c}.',
48
+ lambda c: f'a low resolution photo of a {c}.',
49
+ lambda c: f'the toy {c}.',
50
+ lambda c: f'a rendition of the {c}.',
51
+ lambda c: f'a photo of the clean {c}.',
52
+ lambda c: f'a photo of a large {c}.',
53
+ lambda c: f'a rendition of a {c}.',
54
+ lambda c: f'a photo of a nice {c}.',
55
+ lambda c: f'a photo of a weird {c}.',
56
+ lambda c: f'a blurry photo of a {c}.',
57
+ lambda c: f'a cartoon {c}.',
58
+ lambda c: f'art of a {c}.',
59
+ lambda c: f'a sketch of the {c}.',
60
+ lambda c: f'a embroidered {c}.',
61
+ lambda c: f'a pixelated photo of a {c}.',
62
+ lambda c: f'itap of the {c}.',
63
+ lambda c: f'a jpeg corrupted photo of the {c}.',
64
+ lambda c: f'a good photo of a {c}.',
65
+ lambda c: f'a plushie {c}.',
66
+ lambda c: f'a photo of the nice {c}.',
67
+ lambda c: f'a photo of the small {c}.',
68
+ lambda c: f'a photo of the weird {c}.',
69
+ lambda c: f'the cartoon {c}.',
70
+ lambda c: f'art of the {c}.',
71
+ lambda c: f'a drawing of the {c}.',
72
+ lambda c: f'a photo of the large {c}.',
73
+ lambda c: f'a black and white photo of a {c}.',
74
+ lambda c: f'the plushie {c}.',
75
+ lambda c: f'a dark photo of a {c}.',
76
+ lambda c: f'itap of a {c}.',
77
+ lambda c: f'graffiti of the {c}.',
78
+ lambda c: f'a toy {c}.',
79
+ lambda c: f'itap of my {c}.',
80
+ lambda c: f'a photo of a cool {c}.',
81
+ lambda c: f'a photo of a small {c}.',
82
+ lambda c: f'a tattoo of the {c}.',
83
+ )
84
+
preprocess.py ADDED
@@ -0,0 +1,42 @@
1
+ ## Imports
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ import os.path
6
+ import argparse
7
+ from pathlib import Path
8
+ import cv2
9
+ from torch.nn import functional as F
10
+ from torch.utils.data import DataLoader
11
+ import tqdm
12
+ import einops
13
+ import plotly.express as px
14
+ import torch.nn.functional as F
15
+ import tqdm
16
+ import json
17
+ import albumentations
18
+ import glob
19
+ from torchvision import transforms
20
+
21
+
22
+ def _convert_to_rgb(image):
23
+ return image.convert('RGB')
24
+
25
+ def _resize(image):
26
+ image = np.array(image)
27
+ image = albumentations.augmentations.geometric.resize.LongestMaxSize(interpolation=Image.BICUBIC,
28
+ max_size=224)(image=image)
29
+ return Image.fromarray(image['image'])
30
+
31
+ preprocess = transforms.Compose([
32
+ _resize,
33
+ transforms.CenterCrop(size=(224, 224)),
34
+ _convert_to_rgb,
35
+ ])
36
+
37
+
38
+ both_preprocess = transforms.Compose([
39
+ transforms.ToTensor(),
40
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
41
+ std=(0.26862954, 0.26130258, 0.27577711)),
42
+ ])
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "crop_size": 336,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.48145466,
9
+ 0.4578275,
10
+ 0.40821073
11
+ ],
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "resample": 3,
18
+ "size": 336
19
+ }
preprocessor_config_bak.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "image_size": 224,
3
+ "image_mean": [
4
+ 0.48145466,
5
+ 0.4578275,
6
+ 0.40821073
7
+ ],
8
+ "image_std": [
9
+ 0.26862954,
10
+ 0.26130258,
11
+ 0.27577711
12
+ ],
13
+ "interpolation": "bicubic",
14
+ "resize_mode": "center_crop"
15
+ }
16
+
pretrained.py ADDED
@@ -0,0 +1,426 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from functools import partial
6
+ from typing import Dict, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+
11
+ try:
12
+ from huggingface_hub import hf_hub_download
13
+ hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version='2.20.0')
14
+ _has_hf_hub = True
15
+ except ImportError:
16
+ hf_hub_download = None
17
+ _has_hf_hub = False
18
+
19
+
20
+ def _pcfg(url='', hf_hub='', mean=None, std=None):
21
+ return dict(
22
+ url=url,
23
+ hf_hub=hf_hub,
24
+ mean=mean,
25
+ std=std,
26
+ )
27
+
28
+
29
+ _RN50 = dict(
30
+ openai=_pcfg(
31
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
32
+ yfcc15m=_pcfg(
33
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
34
+ cc12m=_pcfg(
35
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
36
+ )
37
+
38
+ _RN50_quickgelu = dict(
39
+ openai=_pcfg(
40
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
41
+ yfcc15m=_pcfg(
42
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
43
+ cc12m=_pcfg(
44
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
45
+ )
46
+
47
+ _RN101 = dict(
48
+ openai=_pcfg(
49
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
50
+ yfcc15m=_pcfg(
51
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
52
+ )
53
+
54
+ _RN101_quickgelu = dict(
55
+ openai=_pcfg(
56
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
57
+ yfcc15m=_pcfg(
58
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
59
+ )
60
+
61
+ _RN50x4 = dict(
62
+ openai=_pcfg(
63
+ "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
64
+ )
65
+
66
+ _RN50x16 = dict(
67
+ openai=_pcfg(
68
+ "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
69
+ )
70
+
71
+ _RN50x64 = dict(
72
+ openai=_pcfg(
73
+ "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
74
+ )
75
+
76
+ _VITB32 = dict(
77
+ openai=_pcfg(
78
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
79
+ laion400m_e31=_pcfg(
80
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
81
+ laion400m_e32=_pcfg(
82
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
83
+ laion2b_e16=_pcfg(
84
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
85
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
86
+ # DataComp-M models
87
+ datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
88
+ commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
89
+ commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
90
+ commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
91
+ commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
92
+ commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
93
+ commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
94
+ # DataComp-S models
95
+ datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
96
+ commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
97
+ commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
98
+ commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
99
+ commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
100
+ commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
101
+ commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
102
+ )
103
+
104
+ _VITB32_quickgelu = dict(
105
+ openai=_pcfg(
106
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
107
+ laion400m_e31=_pcfg(
108
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
109
+ laion400m_e32=_pcfg(
110
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
111
+ )
112
+
113
+ _VITB16 = dict(
114
+ openai=_pcfg(
115
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
116
+ laion400m_e31=_pcfg(
117
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
118
+ laion400m_e32=_pcfg(
119
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
120
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
121
+ # DataComp-L models
122
+ datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
123
+ commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
124
+ commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
125
+ commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
126
+ commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
127
+ commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
128
+ commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
129
+ )
130
+
131
+ _VITB16_PLUS_240 = dict(
132
+ laion400m_e31=_pcfg(
133
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
134
+ laion400m_e32=_pcfg(
135
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
136
+ )
137
+
138
+ _VITL14 = dict(
139
+ openai=_pcfg(
140
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
141
+ laion400m_e31=_pcfg(
142
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
143
+ laion400m_e32=_pcfg(
144
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
145
+ laion2b_s32b_b82k=_pcfg(
146
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
147
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
148
+ # DataComp-XL models
149
+ datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
150
+ commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
151
+ commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
152
+ commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
153
+ )
154
+
155
+ _VITL14_336 = dict(
156
+ openai=_pcfg(
157
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
158
+ )
159
+
160
+ _VITH14 = dict(
161
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
162
+ )
163
+
164
+ _VITg14 = dict(
165
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
166
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
167
+ )
168
+
169
+ _VITbigG14 = dict(
170
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
171
+ )
172
+
173
+ _robertaViTB32 = dict(
174
+ laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
175
+ )
176
+
177
+ _xlmRobertaBaseViTB32 = dict(
178
+ laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
179
+ )
180
+
181
+ _xlmRobertaLargeFrozenViTH14 = dict(
182
+ frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
183
+ )
184
+
185
+ _convnext_base = dict(
186
+ laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
187
+ )
188
+
189
+ _convnext_base_w = dict(
190
+ laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
191
+ laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
192
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
193
+ )
194
+
195
+ _convnext_base_w_320 = dict(
196
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
197
+ laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
198
+ )
199
+
200
+ _convnext_large_d = dict(
201
+ laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
202
+ )
203
+
204
+ _convnext_large_d_320 = dict(
205
+ laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
206
+ laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
207
+ )
208
+
209
+ _convnext_xxlarge = dict(
210
+ laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
211
+ laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
212
+ laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
213
+ )
214
+
215
+ _coca_VITB32 = dict(
216
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
217
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
218
+ )
219
+
220
+ _coca_VITL14 = dict(
221
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
222
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
223
+ )
224
+
225
+
226
+ _PRETRAINED = {
227
+ "RN50": _RN50,
228
+ "RN50-quickgelu": _RN50_quickgelu,
229
+ "RN101": _RN101,
230
+ "RN101-quickgelu": _RN101_quickgelu,
231
+ "RN50x4": _RN50x4,
232
+ "RN50x16": _RN50x16,
233
+ "RN50x64": _RN50x64,
234
+ "ViT-B-32": _VITB32,
235
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
236
+ "ViT-B-16": _VITB16,
237
+ "ViT-B-16-plus-240": _VITB16_PLUS_240,
238
+ "ViT-L-14": _VITL14,
239
+ "ViT-L-14-336": _VITL14_336,
240
+ "ViT-H-14": _VITH14,
241
+ "ViT-g-14": _VITg14,
242
+ "ViT-bigG-14": _VITbigG14,
243
+ "roberta-ViT-B-32": _robertaViTB32,
244
+ "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
245
+ "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
246
+ "convnext_base": _convnext_base,
247
+ "convnext_base_w": _convnext_base_w,
248
+ "convnext_base_w_320": _convnext_base_w_320,
249
+ "convnext_large_d": _convnext_large_d,
250
+ "convnext_large_d_320": _convnext_large_d_320,
251
+ "convnext_xxlarge": _convnext_xxlarge,
252
+ "coca_ViT-B-32": _coca_VITB32,
253
+ "coca_ViT-L-14": _coca_VITL14,
254
+ "EVA01-g-14": dict(
255
+ # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
256
+ laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
257
+ ),
258
+ "EVA01-g-14-plus": dict(
259
+ # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt
260
+ merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
261
+ ),
262
+ "EVA02-B-16": dict(
263
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt
264
+ merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
265
+ ),
266
+ "EVA02-L-14": dict(
267
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt
268
+ merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
269
+ ),
270
+ "EVA02-L-14-336": dict(
271
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt
272
+ merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
273
+ ),
274
+ "EVA02-E-14": dict(
275
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt
276
+ laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
277
+ ),
278
+ "EVA02-E-14-plus": dict(
279
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt
280
+ laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
281
+ )
282
+ }
283
+
284
+
285
+ def _clean_tag(tag: str):
286
+ # normalize pretrained tags
287
+ return tag.lower().replace('-', '_')
288
+
289
+
290
+ def list_pretrained(as_str: bool = False):
291
+ """ returns list of pretrained models
292
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
293
+ """
294
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
295
+
296
+
297
+ def list_pretrained_models_by_tag(tag: str):
298
+ """ return all models having the specified pretrain tag """
299
+ models = []
300
+ tag = _clean_tag(tag)
301
+ for k in _PRETRAINED.keys():
302
+ if tag in _PRETRAINED[k]:
303
+ models.append(k)
304
+ return models
305
+
306
+
307
+ def list_pretrained_tags_by_model(model: str):
308
+ """ return all pretrain tags for the specified model architecture """
309
+ tags = []
310
+ if model in _PRETRAINED:
311
+ tags.extend(_PRETRAINED[model].keys())
312
+ return tags
313
+
314
+
315
+ def is_pretrained_cfg(model: str, tag: str):
316
+ if model not in _PRETRAINED:
317
+ return False
318
+ return _clean_tag(tag) in _PRETRAINED[model]
319
+
320
+
321
+ def get_pretrained_cfg(model: str, tag: str):
322
+ if model not in _PRETRAINED:
323
+ return {}
324
+ model_pretrained = _PRETRAINED[model]
325
+ return model_pretrained.get(_clean_tag(tag), {})
326
+
327
+
328
+ def get_pretrained_url(model: str, tag: str):
329
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
330
+ return cfg.get('url', '')
331
+
332
+
333
+ def download_pretrained_from_url(
334
+ url: str,
335
+ cache_dir: Union[str, None] = None,
336
+ ):
337
+ if not cache_dir:
338
+ cache_dir = os.path.expanduser("~/.cache/clip")
339
+ os.makedirs(cache_dir, exist_ok=True)
340
+ filename = os.path.basename(url)
341
+
342
+ if 'openaipublic' in url:
343
+ expected_sha256 = url.split("/")[-2]
344
+ elif 'mlfoundations' in url:
345
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
346
+ else:
347
+ expected_sha256 = ''
348
+
349
+ download_target = os.path.join(cache_dir, filename)
350
+
351
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
352
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
353
+
354
+ if os.path.isfile(download_target):
355
+ if expected_sha256:
356
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
357
+ return download_target
358
+ else:
359
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
360
+ else:
361
+ return download_target
362
+
363
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
364
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
365
+ while True:
366
+ buffer = source.read(8192)
367
+ if not buffer:
368
+ break
369
+
370
+ output.write(buffer)
371
+ loop.update(len(buffer))
372
+
373
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
374
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
375
+
376
+ return download_target
377
+
378
+
379
+ def has_hf_hub(necessary=False):
380
+ if not _has_hf_hub and necessary:
381
+ # if no HF Hub module installed, and it is necessary to continue, raise error
382
+ raise RuntimeError(
383
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
384
+ return _has_hf_hub
385
+
386
+
387
+ def download_pretrained_from_hf(
388
+ model_id: str,
389
+ filename: str = 'open_clip_pytorch_model.bin',
390
+ revision=None,
391
+ cache_dir: Union[str, None] = None,
392
+ ):
393
+ has_hf_hub(True)
394
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
395
+ return cached_file
396
+
397
+
398
+ def download_pretrained(
399
+ cfg: Dict,
400
+ force_hf_hub: bool = False,
401
+ cache_dir: Union[str, None] = None,
402
+ ):
403
+ target = ''
404
+ if not cfg:
405
+ return target
406
+
407
+ download_url = cfg.get('url', '')
408
+ download_hf_hub = cfg.get('hf_hub', '')
409
+ if download_hf_hub and force_hf_hub:
410
+ # use HF hub even if url exists
411
+ download_url = ''
412
+
413
+ if download_url:
414
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
415
+ elif download_hf_hub:
416
+ has_hf_hub(True)
417
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
418
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
419
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
420
+ model_id, filename = os.path.split(download_hf_hub)
421
+ if filename:
422
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
423
+ else:
424
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
425
+
426
+ return target
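A usage sketch of the registry helpers above, assuming `pretrained.py` is importable on its own (it only needs `tqdm`, plus `huggingface_hub` for actual downloads):

```python
from pretrained import list_pretrained_tags_by_model, get_pretrained_cfg, get_pretrained_url

# All pretrain tags registered for the ViT-L/14 architecture.
print(list_pretrained_tags_by_model("ViT-L-14"))

# Full entry (url / hf_hub / mean / std) for one model + tag pair; tags are normalized,
# so upper case and dashes are also accepted.
cfg = get_pretrained_cfg("ViT-L-14", "laion2b_s32b_b82k")
print(cfg["hf_hub"], cfg["mean"], cfg["std"])

# Direct-download URL for the original OpenAI ViT-B/32 weights.
print(get_pretrained_url("ViT-B-32", "openai"))
```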
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7a812f61be88b4148e6c910ea245178ff3663263d54680cdb99dd6bcaed9b32
3
+ size 1711950230
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch>=1.9.0
2
+ transformers>=4.21.0
3
+ torchvision>=0.10.0
4
+ Pillow
5
+ numpy
shared.py ADDED
@@ -0,0 +1,616 @@
1
+ import os
2
+ import subprocess
3
+ import re
4
+ import random
5
+ import matplotlib.pyplot as plt
6
+ import json
7
+ def get_gpu_memory_usage():
8
+ """Returns a list of GPU memory usage in MB."""
9
+ try:
10
+ # Run nvidia-smi command and capture the output
11
+ result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
12
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
13
+
14
+ # Check if the command was successful
15
+ if result.returncode != 0:
16
+ raise RuntimeError(f"nvidia-smi command failed with error: {result.stderr}")
17
+
18
+ # Parse the output to get memory usage
19
+ memory_usages = [int(x) for x in result.stdout.strip().split('\n')]
20
+ return memory_usages
21
+ except Exception as e:
22
+ print(f"Error querying GPU memory usage: {e}")
23
+ return []
24
+
25
+ def set_cuda_visible_device():
26
+ """Sets the CUDA_VISIBLE_DEVICES environment variable to the GPU with the smallest memory usage."""
27
+ memory_usages = get_gpu_memory_usage()
28
+
29
+ if not memory_usages:
30
+ print("No GPU memory usage data available.")
31
+ return
32
+
33
+ # Find the index of the GPU with the smallest memory usage
34
+ min_memory_index = memory_usages.index(min(memory_usages))
35
+
36
+ # Set the CUDA_VISIBLE_DEVICES environment variable
37
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(min_memory_index)
38
+ print(f"Set CUDA_VISIBLE_DEVICES to GPU {min_memory_index} with {memory_usages[min_memory_index]} MB used.")
39
+
40
+ return str(min_memory_index)
41
+
42
+ os.environ["ASN_ROOT_DIR"] = "/home/nickj/asn/second_order_lens"
43
+ os.chdir(os.environ["ASN_ROOT_DIR"])
44
+
45
+ import numpy as np
46
+ import torch
47
+ from PIL import Image
48
+ import os.path
49
+ import argparse
50
+ from pathlib import Path
51
+
52
+ from tqdm import tqdm
53
+ from utils.factory import create_model_and_transforms, get_tokenizer
54
+ from PIL import Image, ImageDraw
55
+
56
+ def get_model(model_name = "ViT-B/16", pretrained = "openai", device = "cuda:0"):
57
+ torch.multiprocessing.set_sharing_strategy("file_system")
58
+ model, _, preprocess = create_model_and_transforms(
59
+ model_name, pretrained=pretrained, force_quick_gelu=True,
60
+ )
61
+ model.to(device)
62
+ model.eval()
63
+ context_length = model.context_length
64
+ vocab_size = model.vocab_size
65
+
66
+ return {
67
+ "model": model,
68
+ "model_name": model_name,
69
+ "pretrained": pretrained,
70
+ "preprocess": preprocess,
71
+ "context_length": context_length,
72
+ "vocab_size": vocab_size
73
+ }
74
+
75
+ img_path = "/datasets/ilsvrc_2024-01-04_1913/val/n04398044/ILSVRC2012_val_00042447.JPEG"
76
+ # img_path = "./sample.jpeg"
77
+ def load_images(preprocess, image_folder = "/datasets/ilsvrc/current/val", count = 100, images_only = True):
78
+ file_list = []
79
+
80
+ for root, dirs, files in os.walk(image_folder):
81
+ for file in files:
82
+ file_list.append(os.path.join(root, file))
83
+
84
+ if count > len(file_list):
85
+ sampled_files = file_list
86
+ else:
87
+ sampled_files = random.sample(file_list, count)
88
+
89
+ image_files = []
90
+
91
+ for filename in sampled_files:
92
+ image_files.append(preprocess(Image.open(filename)))
93
+ if images_only:
94
+ return image_files
95
+ else:
96
+ return image_files, sampled_files
97
+
98
+ def calc_neuron_potentials(model, attn_layers = (1, 2), include_layernorm = True):
99
+ # Calculates the attention-shifting potential scores for every neuron to the attention heads defined by the given layers (relative to the MLP layer)
100
+
101
+ embed_dim = model.visual.transformer.resblocks[0].attn.out_proj.in_features
102
+ num_heads = model.visual.transformer.resblocks[0].attn.num_heads
103
+ head_dim = embed_dim // num_heads
104
+ layers = len(model.visual.transformer.resblocks)
105
+
106
+ results = dict()
107
+
108
+ for neuron_layer in tqdm(range(layers), desc = "Calculating attention shifting potentials"):
109
+ neuron_projection = model.visual.transformer.resblocks[neuron_layer].state_dict()["mlp.c_proj.weight"]
110
+ for l_attn in range(min(layers, neuron_layer + attn_layers[0]), min(layers, neuron_layer + attn_layers[1])):
111
+ ln_vector = model.visual.transformer.resblocks[l_attn].ln_1.state_dict()["weight"]
112
+ attn_matrix = model.visual.transformer.resblocks[l_attn].state_dict()["attn.in_proj_weight"]
113
+ W_Q, W_K, W_V = (attn_matrix[:embed_dim].reshape(num_heads, head_dim, -1),
114
+ attn_matrix[embed_dim:2*embed_dim].reshape(num_heads, head_dim, -1),
115
+ attn_matrix[2*embed_dim:].reshape(num_heads, head_dim, -1))
116
+
117
+ for head_idx in range(num_heads):
118
+ W_Q_h, W_K_h = W_Q[head_idx], W_K[head_idx]
119
+ effects = []
120
+ for i in range(neuron_projection.shape[1]):
121
+ if include_layernorm:
122
+ neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ (neuron_projection[:, i] * ln_vector))
123
+ else:
124
+ neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ neuron_projection[:, i])
125
+ effects.append(neuron_attn_effect)
126
+
127
+ results[(neuron_layer, l_attn, head_idx)] = torch.tensor(effects)
128
+ return results
129
+
130
+ def calc_top_asns(shift_potentials, top_k = 10, per = "layer", layers_away = 1):
131
+ num_layers = max([key[1] for key in shift_potentials.keys()]) + 1 # the last layer has no ASNs by definition
132
+ num_heads = max([key[2] for key in shift_potentials.keys()]) + 1 # head indices are zero-based, so the count is max index + 1
133
+
134
+ top_asns = []
135
+ for layer in range(num_layers - layers_away):
136
+ if per == "layer":
137
+ potentials = []
138
+ for head_idx in range(num_heads):
139
+ potentials.append(shift_potentials[(layer, layer + layers_away, head_idx)])
140
+ potentials = torch.max(torch.stack(potentials, dim = 0), dim = 0).values
141
+ _, sorted_indices = torch.sort(potentials, descending = True)
142
+ top_asns.append(sorted_indices[:top_k].tolist())
143
+ elif per == "head":
144
+ top_layer_asns = []
145
+ for head_idx in range(num_heads):
146
+ _, sorted_indices = torch.sort(shift_potentials[(layer, layer + layers_away, head_idx)], descending = True)
147
+ top_layer_asns.append(sorted_indices[:top_k].tolist())
148
+ top_asns.append(top_layer_asns)
149
+ else:
150
+ raise ValueError(f"Invalid per value: {per}")
151
+ return top_asns
152
+
153
+ def aggregate_attn_map(attn_map, layer, head):
154
+ num_tokens = attn_map.shape[-1]
155
+ assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"
156
+
157
+ num_patches = int((num_tokens - 1) ** 0.5)
158
+ aggregate_scores = torch.sum(attn_map[:, layer, head, 1:, 1:], dim = 1).reshape((1, num_patches, num_patches))
159
+ return aggregate_scores
160
+
161
+ def attn_map_cls_token(attn_map, layer, head):
162
+ # Gets the attention map for the CLS token
163
+ num_tokens = attn_map.shape[-1]
164
+ assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"
165
+
166
+ num_patches = int((num_tokens - 1) ** 0.5)
167
+ attn_map_reshaped = attn_map[:, layer, head, 0, 1:].reshape((1, num_patches, num_patches))
168
+ return attn_map_reshaped
169
+
170
+ def visualize_attn_shift(attn_map1, attn_map2, image, display=True, out=None, min_diff=None, max_diff=None):
171
+ import matplotlib.pyplot as plt
172
+ import numpy as np
173
+
174
+ # Subtract attn_map1 from attn_map2
175
+ diff_map = attn_map2 - attn_map1
176
+
177
+ # Convert the image to RGBA
178
+ image = image.convert("RGBA")
179
+ overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
180
+ draw = ImageDraw.Draw(overlay)
181
+
182
+ # Calculate the size of each attention block
183
+ block_size_x = image.size[0] / diff_map.shape[0]
184
+ block_size_y = image.size[1] / diff_map.shape[1]
185
+
186
+ # Create a colormap
187
+ cmap = plt.get_cmap('coolwarm_r') # reversed 'coolwarm' diverging colormap
188
+
189
+ # Get the min and max values for scaling the colormap
190
+ if max_diff is None:
191
+ max_diff = diff_map.max()
192
+ if min_diff is None:
193
+ min_diff = diff_map.min()
194
+
195
+ for i in range(diff_map.shape[0]):
196
+ for j in range(diff_map.shape[1]):
197
+ # Get the color from the colormap
198
+ intensity = diff_map[i, j]
199
+ normalized_intensity = (intensity - min_diff) / (max_diff - min_diff) # Scale to [0, 1]
200
+ rgba_color = cmap(1 - normalized_intensity) # Invert the normalized intensity
201
+ color = tuple(int(c * 255) for c in rgba_color[:3]) + (int(rgba_color[3] * 128),)
202
+
203
+ # Draw the rectangle on the overlay with transparency
204
+ draw.rectangle(
205
+ [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
206
+ fill=color # Add transparency to the color
207
+ )
208
+
209
+ # Composite the overlay with the original image
210
+ combined = Image.alpha_composite(image, overlay)
211
+
212
+ if display:
213
+ # Display the result
214
+ combined.show()
215
+
216
+ # Show the color scale
217
+ plt.figure(figsize=(6, 1))
218
+ plt.imshow([np.linspace(min_diff, max_diff, 256)], cmap='coolwarm_r', aspect='auto')
219
+ plt.gca().set_visible(False)
220
+ plt.colorbar(orientation="horizontal")
221
+ plt.show()
222
+
223
+ if out is not None:
224
+ combined.save(out)
225
+
226
+ return combined
227
+
228
+ def visualize_attn_shift_binary(attn_map1, attn_map2, image, display=True, out=None):
229
+ # Creates a visualization of the attention shift where green = positive, red = negative values.
230
+ # This is useful when outliers in the difference map would otherwise squash the mid-range values around 0 into a single color
231
+ # Subtract attn_map1 from attn_map2
232
+ diff_map = attn_map2 - attn_map1
233
+
234
+ # Normalize the difference map to range [0, 1] for visualization
235
+ diff_map_normalized = (diff_map - diff_map.min()) / (diff_map.max() - diff_map.min())
236
+ # Convert the image to RGBA
237
+ image = image.convert("RGBA")
238
+ overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
239
+ draw = ImageDraw.Draw(overlay)
240
+
241
+ # Calculate the size of each attention block
242
+ block_size_x = image.size[0] / diff_map.shape[0]
243
+ block_size_y = image.size[1] / diff_map.shape[1]
244
+
245
+ for i in range(diff_map.shape[0]):
246
+ for j in range(diff_map.shape[1]):
247
+ # Calculate the color intensity based on the difference
248
+ intensity = diff_map_normalized[i, j]
249
+ alpha = int(255 * 0.5) # Tone down the alpha to 50%
250
+ if diff_map[i, j] > 0:
251
+ color = (0, int(255 * intensity), 0, alpha) # Green for positive
252
+ else:
253
+ color = (int(255 * (1 - intensity)), 0, 0, alpha) # Red for negative
254
+
255
+ # Draw the rectangle on the overlay
256
+ draw.rectangle(
257
+ [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
258
+ fill=color
259
+ )
260
+
261
+ # Composite the overlay with the original image
262
+ combined = Image.alpha_composite(image, overlay)
263
+
264
+ if display:
265
+ # Display the result
266
+ combined.show()
267
+
268
+ if out is not None:
269
+ combined.save(out)
270
+
271
+ return combined
272
+
273
+ def is_outlier(mean, std, value):
274
+ return value < mean - 2 * std or value > mean + 2 * std
275
+
276
+
277
+ def get_neuron_activations(images, prs_group, model, device = "cuda:0"):
278
+ # Returns neuron activations in shape (num_images, num_layers, num_patches, num_neurons)
279
+ random_neuron_acts = []
280
+ for image in tqdm(images, desc="Processing images"):
281
+ prs_group.reinit()
282
+ image_input = image.unsqueeze(0).to(device)
283
+ representation = model.encode_image(
284
+ image_input, attn_method="head", normalize=False
285
+ )
286
+ prs_group.finalize()
287
+ gelu_outs = prs_group.post_gelu_outputs()
288
+ random_neuron_acts.append(gelu_outs)
289
+ random_neuron_acts = torch.stack(random_neuron_acts, dim = 0)
290
+ return random_neuron_acts
291
+
292
+ def normalize_array(arr):
293
+ min_val = np.min(arr)
294
+ max_val = np.max(arr)
295
+ # Avoid division by zero if all values are the same
296
+ if max_val - min_val == 0:
297
+ return np.zeros_like(arr)
298
+ normalized_arr = (arr - min_val) / (max_val - min_val)
299
+ return normalized_arr
300
+
301
+ def np_l2(arr1, arr2):
302
+ return np.linalg.norm(arr1 - arr2)
303
+
304
+ def best_class(classifier, representation):
305
+ cs = torch.cosine_similarity(classifier, representation.permute(1, 0), dim = 0)
306
+ return torch.argmax(cs).item(), cs[torch.argmax(cs).item()].item()
307
+
308
+ def load_group_attn_shifts(timestamp):
309
+ # Load from Supp1B
310
+ results_dir = "./results/supp1B"
311
+ # dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir)
312
+ # if os.path.isdir(os.path.join(results_dir, d))]
313
+ # latest_dir = max(dirs, key=os.path.getmtime)
314
+
315
+ latest_dir = os.path.join(results_dir, timestamp)
316
+ print(f"Using latest results directory: {latest_dir}")
317
+
318
+ # Load metadata
319
+ with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
320
+ metadata = json.load(f)
321
+
322
+ # Load memory-mapped files
323
+ attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
324
+ dtype=np.float32,
325
+ mode='r',
326
+ shape=tuple(metadata["attention_maps_shape"]))
327
+
328
+ resblocks = np.memmap(os.path.join(latest_dir, "resblocks.mmap"),
329
+ dtype=np.float32,
330
+ mode='r',
331
+ shape=tuple(metadata["resblocks_shape"]))
332
+
333
+ # Get file list from metadata
334
+ file_list = metadata.get("file_list", [])
335
+
336
+ # Get top_k values
337
+ top_k_values = metadata.get("top_k_values", [0])
338
+
339
+ return {
340
+ "attn_maps": attn_maps,
341
+ "resblocks": resblocks,
342
+ "metadata": metadata,
343
+ "file_list": file_list,
344
+ "top_k_values": top_k_values,
345
+ "num_layers": metadata.get("num_layers", 0),
346
+ "num_images": metadata.get("num_images", 0),
347
+ "num_heads": metadata.get("num_heads", 0)
348
+ }
349
+
350
+ def load_individual_attn_shifts(timestamp, supp = "supp1D"):
351
+ results_dir = f"./results/{supp}"
352
+ # dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir)
353
+ # if os.path.isdir(os.path.join(results_dir, d))]
354
+ # latest_dir = max(dirs, key=os.path.getmtime)
355
+
356
+ latest_dir = os.path.join(results_dir, timestamp)
357
+ print(f"Using latest results directory: {latest_dir}")
358
+
359
+ # Load metadata
360
+ with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
361
+ metadata = json.load(f)
362
+
363
+ # Load memory-mapped files
364
+ attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
365
+ dtype=np.float32,
366
+ mode='r',
367
+ shape=tuple(metadata["attention_maps_shape"]))
368
+
369
+ baseline_attn_maps = np.memmap(os.path.join(latest_dir, "baseline_attention_maps.mmap"),
370
+ dtype=np.float32,
371
+ mode='r',
372
+ shape=tuple(metadata["baseline_attention_maps_shape"]))
373
+
374
+ neuron_activations = np.memmap(os.path.join(latest_dir, "neuron_activations.mmap"),
375
+ dtype=np.float32,
376
+ mode='r',
377
+ shape=tuple(metadata["neuron_activations_shape"]))
378
+
379
+ baseline_neuron_activations = np.memmap(os.path.join(latest_dir, "baseline_neuron_activations.mmap"),
380
+ dtype=np.float32,
381
+ mode='r',
382
+ shape=tuple(metadata["baseline_neuron_activations_shape"]))
383
+
384
+ ablated_neurons = np.memmap(os.path.join(latest_dir, "ablated_neurons.mmap"),
385
+ dtype=np.float32,
386
+ mode='r',
387
+ shape=tuple(metadata["ablated_neurons_shape"]))
388
+
389
+ # Get file list from metadata
390
+ file_list = metadata.get("file_list", [])
391
+
392
+ # Get k value
393
+ k = metadata.get("k", 25)
394
+
395
+ return {
396
+ "attn_maps": attn_maps,
397
+ "baseline_attn_maps": baseline_attn_maps,
398
+ "neuron_activations": neuron_activations,
399
+ "baseline_neuron_activations": baseline_neuron_activations,
400
+ "ablated_neurons": ablated_neurons,
401
+ "metadata": metadata,
402
+ "file_list": file_list,
403
+ "k": k,
404
+ "num_layers": metadata.get("num_layers", 12),
405
+ "num_images": metadata.get("num_images", 100),
406
+ "model_name": metadata.get("model_name", "ViT-B-16"),
407
+ "pretrained": metadata.get("pretrained", "openai")
408
+ }
409
+
410
+ def find_register_neurons_cuda(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
411
+ num_layers = len(model.visual.transformer.resblocks)
412
+ highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
413
+ num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
414
+ random_images = load_images(preprocess, count=processed_image_cnt)
415
+ neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
416
+ alignment_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
417
+ image_count = 0
418
+
419
+ for i in tqdm(range(len(random_images)), desc="Processing random images"):
420
+ image = random_images[i].unsqueeze(0).to(device)
421
+ prs_group.reinit()
422
+
423
+ with torch.inference_mode():
424
+ representation = model.encode_image(
425
+ image, attn_method="head", normalize=False
426
+ )
427
+ prs_group.finalize()
428
+
429
+ baseline_neuron_acts = prs_group.post_gelu_outputs().to(device)
430
+ baseline_resblock_outputs = prs_group.resblock_outputs().to(device)
431
+
432
+ # Calculate norm map using torch
433
+ norm_map = torch.norm(baseline_resblock_outputs[-1], dim=1)
434
+ filtered_norms = norm_map.clone()
435
+ filtered_norms[filtered_norms < register_norm_threshold] = 0
436
+
437
+ # Get register locations as a tensor
438
+ register_locations = torch.where(filtered_norms > register_norm_threshold)[0]
439
+
440
+ if len(register_locations) == 0:
441
+ continue
442
+
443
+ image_count += 1
444
+
445
+ # Process all layers vectorized
446
+ for layer in range(num_layers):
447
+ # Get absolute activations for all neurons in this layer
448
+ act_layer = torch.abs(baseline_neuron_acts[layer]) # Shape: [seq_len, num_neurons]
449
+
450
+ # Check sparsity condition for all neurons at once
451
+ sparse_neurons = torch.sum(act_layer < 0.5, dim=0) >= 0.5 * act_layer.shape[0] # Shape: [num_neurons]
452
+
453
+ # Skip computation if no neurons meet the condition
454
+ if not torch.any(sparse_neurons):
455
+ continue
456
+
457
+ # Get values at register locations for all neurons simultaneously
458
+ # This creates a tensor of shape [num_register_locations, num_neurons]
459
+ register_values = act_layer[register_locations]
460
+
461
+ # For neurons that pass sparsity condition, compute mean at register locations
462
+ # First, compute mean for all neurons (this is fast)
463
+ neuron_means = register_values.mean(dim=0) # Shape: [num_neurons]
464
+
465
+ # Then zero out means for neurons that don't pass sparsity condition
466
+ neuron_means = neuron_means * sparse_neurons.float()
467
+
468
+ # Store in score tensor
469
+ neuron_scores[i, layer] = neuron_means
470
+
471
+ # Rest of the code remains the same
472
+ mean_neuron_scores = neuron_scores[:image_count].mean(dim=0)
473
+ mean_alignment_scores = alignment_scores[:image_count].mean(dim=0)
474
+
475
+ # Flatten and find top values
476
+ flattened_scores = mean_neuron_scores.flatten()
477
+ sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)
478
+
479
+ flattened_alignment = mean_alignment_scores.flatten()
480
+ sorted_alignment_values, sorted_alignment_indices = torch.sort(flattened_alignment, descending=True)
481
+
482
+ # Convert indices to layer/neuron pairs
483
+ top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
484
+ top_alignment_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_alignment_indices]
485
+
486
+ register_norms = [
487
+ (layer, neuron, sorted_values[i].item())
488
+ for i, (layer, neuron) in enumerate(top_indices)
489
+ if layer <= highest_layer
490
+ ]
491
+
492
+ best_alignment_scores = [
493
+ (layer, neuron, sorted_alignment_values[i].item())
494
+ for i, (layer, neuron) in enumerate(top_alignment_indices)
495
+ if layer <= highest_layer
496
+ ]
497
+
498
+ return register_norms, best_alignment_scores
499
+
500
+ def find_register_neurons(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
501
+ num_layers = len(model.visual.transformer.resblocks)
502
+ highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
503
+ num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
504
+
505
+ random_images = load_images(preprocess, count = processed_image_cnt)
506
+ neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons))
507
+ for i in tqdm(range(len(random_images)), desc="Processing random images"):
508
+ image = random_images[i].unsqueeze(0).to(device)
509
+
510
+ prs_group.reinit()
511
+ with torch.no_grad():
512
+ representation = model.encode_image(
513
+ image, attn_method="head", normalize=False
514
+ )
515
+ prs_group.finalize()
516
+
517
+ # Gather neuron activations and resblock outputs
518
+ baseline_neuron_acts = prs_group.post_gelu_outputs().cpu().numpy()
519
+ baseline_resblock_outputs = prs_group.resblock_outputs().cpu().numpy()
520
+
521
+ # Calculate norms of the last resblock outputs. Only consider patches of the activation maps that correspond with registers
522
+ norms = np.linalg.norm(baseline_resblock_outputs[-1], axis=1)
523
+ norms[norms < register_norm_threshold] = 0
524
+ register_locations = np.where(norms > register_norm_threshold)[0]
525
+
526
+ # register_neurons = []
527
+ for layer in range(num_layers):
528
+ for neuron in range(num_neurons):
529
+ neuron_map = baseline_neuron_acts[layer, :, neuron]
530
+ mask = np.zeros_like(neuron_map, dtype=bool)
531
+ mask[register_locations] = True
532
+ neuron_map[~mask] = 0
533
+ if np.any(neuron_map < 0):
534
+ continue
535
+ # dist = np.linalg.norm(normalize_array(norms) - normalize_array(neuron_map))
536
+ # register_neurons.append((layer, neuron, dist.item(), neuron_map[register_locations].mean()))
537
+
538
+ neuron_scores[i, layer, neuron] = torch.tensor(neuron_map[register_locations].mean())
539
+ mean_neuron_scores = neuron_scores.mean(dim=0)
540
+ # Flatten the 2D tensor to find global top values
541
+ flattened_scores = mean_neuron_scores.flatten()
542
+ sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)
543
+
544
+ # Convert flat indices back to 2D coordinates (layer, neuron)
545
+ top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
546
+
547
+ return [(layer, neuron, sorted_values[i].item()) for i, (layer, neuron) in enumerate(top_indices) if layer <= highest_layer]
548
+
549
+
550
+ def plot_attn_maps(attn_maps, image_idx):
551
+
552
+ num_layers, num_heads, patch_height, patch_width = attn_maps.shape
553
+ print(f"Shape of image_shifts: {attn_maps.shape}")
554
+
555
+ # Create a grid of plots for all layers and heads
556
+ fig, axes = plt.subplots(num_layers, num_heads, figsize=(2*num_heads, 2*num_layers))
557
+ fig.suptitle(f'Attention Shift Maps for Image #{image_idx}', fontsize=16)
558
+
559
+ # Import the correct module for make_axes_locatable
560
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
561
+
562
+ # Plot each layer-head combination
563
+ for layer in range(num_layers):
564
+ # Determine min and max for this layer for consistent colorbar scaling within the layer
565
+ layer_vmin = attn_maps[layer].min().item()
566
+ layer_vmax = attn_maps[layer].max().item()
567
+
568
+ for head in range(num_heads):
569
+ # Get the current axis (handle both 2D and 1D cases)
570
+ if num_layers == 1 and num_heads == 1:
571
+ ax = axes
572
+ elif num_layers == 1:
573
+ ax = axes[head]
574
+ elif num_heads == 1:
575
+ ax = axes[layer]
576
+ else:
577
+ ax = axes[layer, head]
578
+
579
+ # Plot the attention shift map with layer-specific normalization
580
+ im = ax.imshow(attn_maps[layer, head], cmap='viridis', vmin=layer_vmin, vmax=layer_vmax)
581
+
582
+ # Remove ticks for cleaner appearance
583
+ ax.set_xticks([])
584
+ ax.set_yticks([])
585
+
586
+ # Add layer and head labels only on the edges
587
+ if head == 0:
588
+ ax.set_ylabel(f'Layer {layer}')
589
+ if layer == num_layers-1:
590
+ ax.set_xlabel(f'Head {head}')
591
+
592
+ # Add a colorbar for each layer (only once per row)
593
+ if head == num_heads-1:
594
+ # Create a colorbar that's properly sized relative to the plot
595
+ divider = make_axes_locatable(ax)
596
+ cax = divider.append_axes("right", size="5%", pad=0.05)
597
+ plt.colorbar(im, cax=cax)
598
+
599
+ # Adjust layout to make room for the colorbars
600
+ plt.tight_layout()
601
+ return plt
602
+
603
+ def calculate_iou(output, target):
604
+ intersection = output * (output == target)
605
+ area_inter = intersection.sum().item()
606
+ area_pred = output.sum().item()
607
+ area_target = target.sum().item()
608
+ union = area_pred + area_target - area_inter
609
+ iou = area_inter / union
610
+ return area_inter, union, iou
611
+
612
+ def calculate_pixel_accuracy(output, target):
613
+ correct = output * (output == target)
614
+ correct = correct.sum().item()
615
+ total = target.sum().item()
616
+ return correct, total, correct / total
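The two metric helpers at the end of this file are plain tensor arithmetic. A standalone sketch of the same IoU computation on toy binary masks (copied out rather than imported, since the module `chdir`s to a hardcoded path at import time):

```python
import torch

def calculate_iou(output, target):
    # Same logic as the helper above: intersection over union of binary masks.
    intersection = output * (output == target)
    area_inter = intersection.sum().item()
    area_pred = output.sum().item()
    area_target = target.sum().item()
    union = area_pred + area_target - area_inter
    return area_inter, union, area_inter / union

pred = torch.tensor([[1, 1, 0], [0, 1, 0]])
gt = torch.tensor([[1, 0, 0], [0, 1, 1]])
inter, union, iou = calculate_iou(pred, gt)
print(inter, union, iou)  # 2 4 0.5
```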
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
timm_model.py ADDED
@@ -0,0 +1,149 @@
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as the vision tower in a CLIP model.
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import timm
13
+ from timm.models.layers import Mlp, to_2tuple
14
+ try:
15
+ # old timm imports < 0.8.1
16
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
+ except ImportError:
19
+ # new timm imports >= 0.8.1
20
+ from timm.layers import RotAttentionPool2d
21
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
+ except ImportError:
23
+ timm = None
24
+
25
+ from misc import freeze_batch_norm_2d
26
+
27
+
28
+ class TimmModel(nn.Module):
29
+ """ timm model adapter
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ model_name,
35
+ embed_dim,
36
+ image_size=224,
37
+ pool='avg',
38
+ proj='linear',
39
+ proj_bias=False,
40
+ drop=0.,
41
+ drop_path=None,
42
+ patch_drop=None,
43
+ pretrained=False,
44
+ ):
45
+ super().__init__()
46
+ if timm is None:
47
+ raise RuntimeError("Please `pip install timm` to use timm models.")
48
+ self.image_size = to_2tuple(image_size)
49
+
50
+ # setup kwargs that may not be common across all models
51
+ timm_kwargs = {}
52
+ if drop_path is not None:
53
+ timm_kwargs['drop_path_rate'] = drop_path
54
+ if patch_drop is not None:
55
+ timm_kwargs['patch_drop_rate'] = patch_drop
56
+
57
+ custom_pool = pool in ('abs_attn', 'rot_attn')
58
+ if not proj and not custom_pool:
59
+ # use network classifier head as projection if no proj specified and no custom pooling used
60
+ self.trunk = timm.create_model(
61
+ model_name,
62
+ num_classes=embed_dim,
63
+ global_pool=pool,
64
+ pretrained=pretrained,
65
+ **timm_kwargs,
66
+ )
67
+ prev_chs = embed_dim
68
+ else:
69
+ self.trunk = timm.create_model(
70
+ model_name,
71
+ pretrained=pretrained,
72
+ **timm_kwargs,
73
+ )
74
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
75
+ feature_ndim = 1 if not feat_size else 2
76
+ if custom_pool:
77
+ assert feature_ndim == 2
78
+ # if attn pooling used, remove both classifier and default pool
79
+ self.trunk.reset_classifier(0, global_pool='')
80
+ else:
81
+ # reset global pool if pool config set, otherwise leave as network default
82
+ reset_kwargs = dict(global_pool=pool) if pool else {}
83
+ self.trunk.reset_classifier(0, **reset_kwargs)
84
+ prev_chs = self.trunk.num_features
85
+
86
+ head_layers = OrderedDict()
87
+
88
+ # Add custom pooling to head
89
+ if pool == 'abs_attn':
90
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
91
+ prev_chs = embed_dim
92
+ elif pool == 'rot_attn':
93
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
94
+ prev_chs = embed_dim
95
+
96
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
97
+ if proj == 'linear':
98
+ head_layers['drop'] = nn.Dropout(drop)
99
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
100
+ elif proj == 'mlp':
101
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
102
+ else:
103
+ assert not proj, f'Unknown projection type {proj}.'
104
+
105
+ self.head = nn.Sequential(head_layers)
106
+
107
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
108
+ """ lock modules
109
+ Args:
110
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
111
+ """
112
+ if not unlocked_groups:
113
+ # lock full model
114
+ for param in self.trunk.parameters():
115
+ param.requires_grad = False
116
+ if freeze_bn_stats:
117
+ freeze_batch_norm_2d(self.trunk)
118
+ else:
119
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
120
+ try:
121
+ # FIXME import here until API stable and in an official release
122
+ from timm.models.helpers import group_parameters, group_modules
123
+ except ImportError:
124
+ raise RuntimeError(
125
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
126
+ matcher = self.trunk.group_matcher()
127
+ gparams = group_parameters(self.trunk, matcher)
128
+ max_layer_id = max(gparams.keys())
129
+ max_layer_id = max_layer_id - unlocked_groups
130
+ for group_idx in range(max_layer_id + 1):
131
+ group = gparams[group_idx]
132
+ for param in group:
133
+ self.trunk.get_parameter(param).requires_grad = False
134
+ if freeze_bn_stats:
135
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
136
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
137
+ freeze_batch_norm_2d(self.trunk, gmodules)
138
+
139
+ @torch.jit.ignore
140
+ def set_grad_checkpointing(self, enable=True):
141
+ try:
142
+ self.trunk.set_grad_checkpointing(enable)
143
+ except Exception as e:
144
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
145
+
146
+ def forward(self, x):
147
+ x = self.trunk(x)
148
+ x = self.head(x)
149
+ return x
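A sketch of wrapping a timm backbone with the adapter above, assuming `timm` is installed and the sibling `misc` module (which provides `freeze_batch_norm_2d`) is on the path; the backbone name and embedding width are arbitrary choices for illustration:

```python
import torch
from timm_model import TimmModel

# Wrap a timm ViT as a CLIP-style vision tower projecting to a 512-d embedding.
tower = TimmModel("vit_base_patch16_224", embed_dim=512, image_size=224, pretrained=False)
tower.eval()

with torch.no_grad():
    feats = tower(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 512])
```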
tokenizer.py ADDED
@@ -0,0 +1,214 @@
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+ # https://stackoverflow.com/q/62691279
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+
20
+ @lru_cache()
21
+ def default_bpe():
22
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab/bpe_simple_vocab_16e6.txt.gz")
23
+
24
+
25
+ @lru_cache()
26
+ def bytes_to_unicode():
27
+ """
28
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
29
+ The reversible bpe codes work on unicode strings.
30
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
+ This is a significant percentage of your normal, say, 32K bpe vocab.
33
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
35
+ """
36
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
+ cs = bs[:]
38
+ n = 0
39
+ for b in range(2**8):
40
+ if b not in bs:
41
+ bs.append(b)
42
+ cs.append(2**8+n)
43
+ n += 1
44
+ cs = [chr(n) for n in cs]
45
+ return dict(zip(bs, cs))
46
+
47
+
48
+ def get_pairs(word):
49
+ """Return set of symbol pairs in a word.
50
+ Word is represented as tuple of symbols (symbols being variable-length strings).
51
+ """
52
+ pairs = set()
53
+ prev_char = word[0]
54
+ for char in word[1:]:
55
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


_tokenizer = SimpleTokenizer()

def decode(output_ids: torch.Tensor):
    output_ids = output_ids.cpu().numpy()
    return _tokenizer.decode(output_ids)

def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


class HFTokenizer:
    """HuggingFace tokenizer wrapper"""

    def __init__(self, tokenizer_name: str):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def save_pretrained(self, dest):
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]
        texts = [whitespace_clean(basic_clean(text)) for text in texts]
        input_ids = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        ).input_ids
        return input_ids
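
The snippet below is a minimal usage sketch of the tokenizer defined above, not part of the commit itself; it assumes the file is importable as `tokenizer`, that `default_bpe()` resolves to the `vocab/bpe_simple_vocab_16e6.txt.gz` file added later in this upload, and that `ftfy` and `regex` are installed.

```python
# Hypothetical usage sketch -- module/file names are assumptions, not part of this commit.
from tokenizer import tokenize, _tokenizer

ids = _tokenizer.encode("a photo of a cat")
print(_tokenizer.decode(ids))          # round-trips to "a photo of a cat "

batch = tokenize(["a photo of a cat", "a photo of a dog"])
print(batch.shape)                     # torch.Size([2, 77]), zero-padded, SOT/EOT added
```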
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
{
  "unk_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "bos_token": {
    "content": "<|startoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "pad_token": "<|endoftext|>",
  "add_prefix_space": false,
  "errors": "replace",
  "do_lower_case": true,
  "name_or_path": "openai/clip-vit-base-patch32",
  "model_max_length": 77,
  "special_tokens_map_file": "/home/suraj/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
  "tokenizer_class": "CLIPTokenizer"
}
tokenizer_config_bak.json ADDED
@@ -0,0 +1,10 @@
{
  "tokenizer_class": "SimpleTokenizer",
  "vocab_size": 49408,
  "context_length": 77,
  "bpe_path": "vocab/bpe_simple_vocab_16e6.txt.gz",
  "special_tokens": [
    "<start_of_text>",
    "<end_of_text>"
  ]
}
transform.py ADDED
@@ -0,0 +1,133 @@
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms.functional as F

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD


@dataclass
class AugmentationCfg:
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
    interpolation: Optional[str] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    use_timm: bool = False


class ResizeMaxSize(nn.Module):

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        self.fn = min if fn == 'min' else max
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            height, width = img.shape[:2]
        else:
            width, height = img.size
        scale = self.max_size / float(max(height, width))
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
            pad_h = self.max_size - new_size[0]
            pad_w = self.max_size - new_size[1]
            img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
        return img


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(
        image_size: int,
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    if isinstance(aug_cfg, dict):
        aug_cfg = AugmentationCfg(**aug_cfg)
    else:
        aug_cfg = aug_cfg or AugmentationCfg()
    normalize = Normalize(mean=mean, std=std)
    if is_train:
        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
        use_timm = aug_cfg_dict.pop('use_timm', False)
        if use_timm:
            from timm.data import create_transform  # timm can still be optional
            if isinstance(image_size, (tuple, list)):
                assert len(image_size) >= 2
                input_size = (3,) + image_size[-2:]
            else:
                input_size = (3, image_size, image_size)
            # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
            aug_cfg_dict.setdefault('interpolation', 'random')
            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
            train_transform = create_transform(
                input_size=input_size,
                is_training=True,
                hflip=0.,
                mean=mean,
                std=std,
                re_mode='pixel',
                **aug_cfg_dict,
            )
        else:
            train_transform = Compose([
                RandomResizedCrop(
                    image_size,
                    scale=aug_cfg_dict.pop('scale'),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                _convert_to_rgb,
                ToTensor(),
                normalize,
            ])
            if aug_cfg_dict:
                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
        return train_transform
    else:
        if resize_longest_max:
            transforms = [
                ResizeMaxSize(image_size, fill=fill_color)
            ]
        else:
            transforms = [
                Resize(image_size, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(image_size),
            ]
        transforms.extend([
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
        return Compose(transforms)
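
As a quick illustration (not part of the commit): the evaluation path of `image_transform` above composes Resize, CenterCrop, RGB conversion, ToTensor, and Normalize. A minimal sketch, assuming the file is importable as `transform` and an example image exists on disk:

```python
# Hypothetical usage sketch -- the image path is a placeholder.
from PIL import Image
from transform import image_transform

preprocess = image_transform(image_size=224, is_train=False)
img = Image.open("example.jpg")
x = preprocess(img).unsqueeze(0)       # tensor of shape [1, 3, 224, 224]
```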
transformer.py ADDED
@@ -0,0 +1,872 @@
1
+ from collections import OrderedDict
2
+ import math
3
+ from typing import Callable, Optional, Sequence, Tuple, Text
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.utils.checkpoint import checkpoint
9
+ import numbers
10
+ import einops
11
+ import numpy as np
12
+ from misc import to_2tuple
13
+
14
+
15
+ class LayerNorm(nn.Module):
16
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
17
+
18
+ def __init__(
19
+ self,
20
+ normalized_shape,
21
+ eps: float = 1e-5,
22
+ elementwise_affine: bool = True,
23
+ device=None,
24
+ dtype=None,
25
+ ):
26
+ super().__init__()
27
+ if isinstance(normalized_shape, numbers.Integral):
28
+ normalized_shape = (normalized_shape,)
29
+ self.normalized_shape = tuple(normalized_shape)
30
+ self.eps = eps
31
+ self.elementwise_affine = elementwise_affine
32
+ if self.elementwise_affine:
33
+ self.weight = torch.nn.Parameter(
34
+ torch.empty(self.normalized_shape)
35
+ )
36
+ self.bias = torch.nn.Parameter(
37
+ torch.empty(self.normalized_shape)
38
+ )
39
+ else:
40
+ self.register_parameter("weight", None)
41
+ self.register_parameter("bias", None)
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ orig_type = x.dtype
45
+ assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
46
+ dims = [-(i + 1) for i in range(len(self.normalized_shape))]
47
+ mean = x.mean(dim=dims, keepdim=True)
48
+ mean_x2 = (x**2).mean(dim=dims, keepdim=True)
49
+ var = mean_x2 - mean**2
50
+ x_norm = (x - mean) / torch.sqrt(var + self.eps)
51
+ if self.elementwise_affine:
52
+ x_norm = self.weight * x_norm + self.bias
53
+ return x_norm.to(orig_type)
54
+
55
+
56
+ class QuickGELU(nn.Module):
57
+ def forward(self, x: torch.Tensor):
58
+ return x * torch.sigmoid(1.702 * x)
59
+
60
+
61
+ class LayerScale(nn.Module):
62
+ def __init__(self, dim, init_values=1e-5, inplace=False):
63
+ super().__init__()
64
+ self.inplace = inplace
65
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
66
+
67
+ def forward(self, x):
68
+ raise ValueError("Not implemented")
69
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
70
+
71
+
72
+ class PatchDropout(nn.Module):
73
+ """
74
+ https://arxiv.org/abs/2212.00794
75
+ """
76
+
77
+ def __init__(self, prob, exclude_first_token=True):
78
+ super().__init__()
79
+ assert 0 <= prob < 1.0
80
+ self.prob = prob
81
+ self.exclude_first_token = exclude_first_token
82
+
83
+ def forward(self, x):
84
+ if not self.training or self.prob == 0.0:
85
+ return x
86
+
87
+ if self.exclude_first_token:
88
+ cls_tokens, x = x[:, :1], x[:, 1:]
89
+ else:
90
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
91
+
92
+ batch = x.size()[0]
93
+ num_tokens = x.size()[1]
94
+
95
+ batch_indices = torch.arange(batch)
96
+ batch_indices = batch_indices[..., None]
97
+
98
+ keep_prob = 1 - self.prob
99
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
100
+
101
+ rand = torch.randn(batch, num_tokens)
102
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
103
+
104
+ x = x[batch_indices, patch_indices_keep]
105
+
106
+ if self.exclude_first_token:
107
+ x = torch.cat((cls_tokens, x), dim=1)
108
+
109
+ return x
110
+
111
+
112
+ class Attention(nn.Module):
113
+ def __init__(
114
+ self,
115
+ dim,
116
+ num_heads=8,
117
+ qkv_bias=True,
118
+ scaled_cosine=False,
119
+ scale_heads=False,
120
+ logit_scale_max=math.log(1.0 / 0.01),
121
+ attn_drop=0.0,
122
+ proj_drop=0.0,
123
+ ):
124
+ super().__init__()
125
+ self.scaled_cosine = scaled_cosine
126
+ self.scale_heads = scale_heads
127
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
128
+ self.num_heads = num_heads
129
+ self.head_dim = dim // num_heads
130
+ self.scale = self.head_dim**-0.5
131
+ self.logit_scale_max = logit_scale_max
132
+
133
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
134
+ if qkv_bias:
135
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
136
+ else:
137
+ self.in_proj_bias = None
138
+
139
+ if self.scaled_cosine:
140
+ self.logit_scale = nn.Parameter(
141
+ torch.log(10 * torch.ones((num_heads, 1, 1)))
142
+ )
143
+ else:
144
+ self.logit_scale = None
145
+ self.attn_drop = nn.Dropout(attn_drop)
146
+ if self.scale_heads:
147
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
148
+ else:
149
+ self.head_scale = None
150
+ self.out_proj = nn.Linear(dim, dim)
151
+ self.out_drop = nn.Dropout(proj_drop)
152
+
153
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
154
+ L, N, C = x.shape
155
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
156
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
157
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
158
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
159
+
160
+ if self.logit_scale is not None:
161
+ attn = torch.bmm(
162
+ F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)
163
+ )
164
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
165
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
166
+ attn = attn.view(-1, L, L)
167
+ else:
168
+ q = q * self.scale
169
+ attn = torch.bmm(q, k.transpose(-1, -2))
170
+
171
+ if attn_mask is not None:
172
+ if attn_mask.dtype == torch.bool:
173
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
174
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
175
+ attn_mask = new_attn_mask
176
+ attn += attn_mask
177
+
178
+ attn = attn.softmax(dim=-1)
179
+ attn = self.attn_drop(attn)
180
+
181
+ x = torch.bmm(attn, v)
182
+ if self.head_scale is not None:
183
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
184
+ x = x.view(-1, L, C)
185
+ x = x.transpose(0, 1).reshape(L, N, C)
186
+ x = self.out_proj(x)
187
+ x = self.out_drop(x)
188
+ return x
189
+
190
+
191
+ class AttentionalPooler(nn.Module):
192
+ def __init__(
193
+ self,
194
+ d_model: int,
195
+ context_dim: int,
196
+ n_head: int = 8,
197
+ n_queries: int = 256,
198
+ norm_layer: Callable = LayerNorm,
199
+ ):
200
+ super().__init__()
201
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
202
+ self.attn = nn.MultiheadAttention(
203
+ d_model, n_head, kdim=context_dim, vdim=context_dim
204
+ )
205
+ self.ln_q = norm_layer(d_model)
206
+ self.ln_k = norm_layer(context_dim)
207
+
208
+ def forward(self, x: torch.Tensor):
209
+ x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
210
+ N = x.shape[1]
211
+ q = self.ln_q(self.query)
212
+ out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
213
+ return out.permute(1, 0, 2) # LND -> NLD
214
+
215
+ def _repeat(self, query, N: int):
216
+ return query.unsqueeze(1).repeat(1, N, 1)
217
+
218
+
219
+ class MLP(nn.Module):
220
+ def __init__(
221
+ self,
222
+ d_model: int,
223
+ mlp_width: int,
224
+ act_layer: Callable = nn.GELU,
225
+ layer_id: Optional[int] = None,
226
+ ):
227
+ super().__init__()
228
+ self.c_fc = nn.Linear(d_model, mlp_width)
229
+ self.gelu = act_layer()
230
+ self.c_proj = nn.Linear(mlp_width, d_model)
231
+ self.layer_id = layer_id
232
+
233
+ def forward(self, x, neuron_dict=None, num_register_tokens=0):
234
+ x = self.c_fc(x)
235
+
236
+ # If we have a dictionary of modifications and this layer is in it
237
+ if neuron_dict is not None and self.layer_id in neuron_dict and num_register_tokens>0:
238
+ neurons = neuron_dict[self.layer_id]
239
+
240
+ # Apply GELU to all activations
241
+ x_after_gelu = self.gelu(x)
242
+
243
+ original_activations = x_after_gelu.clone()
244
+ # Create new activation map for specified neurons
245
+ new_activation_map = torch.zeros(
246
+ (x_after_gelu.shape[0], x_after_gelu.shape[1], len(neurons)),
247
+ device=x_after_gelu.device,
248
+ ).to(x_after_gelu.dtype)
249
+
250
+ max_values = torch.max(original_activations[:, :, neurons], dim=1, keepdim=True).values
251
+
252
+ new_activation_map[:, -num_register_tokens:, :] = max_values
253
+ new_activation_map[:,0,:] = x_after_gelu[:,0,neurons]
254
+
255
+ x_after_gelu[:,:,neurons] = new_activation_map
256
+ x = x_after_gelu
257
+ else:
258
+ x = self.gelu(x)
259
+
260
+ x = self.c_proj(x)
261
+ return x
262
+
263
+ # TODO: the problem seems to arise from this not being the custom attention implementation.
264
+ class MultiheadAttention(nn.Module):
265
+ def __init__(
266
+ self,
267
+ embed_dim,
268
+ num_heads,
269
+ dropout=0.0,
270
+ bias=True,
271
+ add_bias_kv=False,
272
+ add_zero_attn=False,
273
+ kdim=None,
274
+ vdim=None,
275
+ batch_first=False,
276
+ device=None,
277
+ dtype=None,
278
+ ):
279
+ super().__init__()
280
+ self.embed_dim = embed_dim
281
+ self.kdim = kdim if kdim is not None else embed_dim
282
+ self.vdim = vdim if vdim is not None else embed_dim
283
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
284
+ self.q_out = nn.Identity()
285
+ self.k_out = nn.Identity()
286
+ self.v_out = nn.Identity()
287
+ self.qkv_out = nn.Identity()
288
+ self.attn_map = nn.Identity()
289
+
290
+ self.num_heads = num_heads
291
+ self.dropout = dropout
292
+ self.batch_first = batch_first
293
+ self.head_dim = embed_dim // num_heads
294
+ assert (
295
+ self.head_dim * num_heads == self.embed_dim
296
+ ), "embed_dim must be divisible by num_heads"
297
+ self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim)))
298
+
299
+ if bias:
300
+ self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
301
+ else:
302
+ self.register_parameter("in_proj_bias", None)
303
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
304
+
305
+ if add_bias_kv:
306
+ self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim)))
307
+ self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim)))
308
+ else:
309
+ self.bias_k = self.bias_v = None
310
+
311
+ self.add_zero_attn = add_zero_attn
312
+
313
+ def forward_direct(self, x, attn_mask=None):
314
+ B, N, C = x.shape
315
+ qkv = x @ self.in_proj_weight.T + self.in_proj_bias
316
+ qkv = self.qkv_out(qkv)
317
+ qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
318
+ # B, S, 3, H, d -> 3, B, H, S, d batch first computation
319
+ # This step seems to be where the computation results start to diverge.
320
+ q, k, v = qkv.unbind(0)
321
+
322
+ q = self.q_out(q)
323
+ k = self.k_out(k)
324
+ v = self.v_out(v)
325
+
326
+ dk = q.size()[-1]
327
+ q = q / math.sqrt(dk)
328
+ attn = q @ k.transpose(-2, -1)
329
+ if attn_mask is not None:
330
+ attn += attn_mask
331
+ attn = attn.softmax(dim=-1)
332
+ attn = self.attn_map(attn)
333
+ x = attn @ v
334
+
335
+ x = x.transpose(1, 2).reshape(B, N, C)
336
+ x = x @ self.out_proj.weight.T + self.out_proj.bias
337
+ return x
338
+
339
+ def _split_qkv_weight(self):
340
+ q_weight, k_weight, v_weight = (
341
+ self.in_proj_weight[: self.embed_dim].reshape(
342
+ self.num_heads, self.head_dim, -1
343
+ ),
344
+ self.in_proj_weight[self.embed_dim : self.embed_dim * 2].reshape(
345
+ self.num_heads, self.head_dim, -1
346
+ ),
347
+ self.in_proj_weight[self.embed_dim * 2 :].reshape(
348
+ self.num_heads, self.head_dim, -1
349
+ ),
350
+ )
351
+ return q_weight, k_weight, v_weight
352
+
353
+ def _split_qkv_bias(self):
354
+ q_bias, k_bias, v_bias = (
355
+ self.in_proj_bias[: self.embed_dim].reshape(
356
+ 1, self.num_heads, 1, self.head_dim
357
+ ),
358
+ self.in_proj_bias[self.embed_dim : self.embed_dim * 2].reshape(
359
+ 1, self.num_heads, 1, self.head_dim
360
+ ),
361
+ self.in_proj_bias[self.embed_dim * 2 :].reshape(
362
+ 1, self.num_heads, 1, self.head_dim
363
+ ),
364
+ )
365
+ return q_bias, k_bias, v_bias
366
+
367
+ def forward_qkv(self, x, attn_mask=None):
368
+ B, N, C = x.shape
369
+ q_weight, k_weight, v_weight = (
370
+ self.in_proj_weight[: self.embed_dim],
371
+ self.in_proj_weight[self.embed_dim : self.embed_dim * 2],
372
+ self.in_proj_weight[self.embed_dim * 2 :],
373
+ )
374
+ q_bias, k_bias, v_bias = (
375
+ self.in_proj_bias[: self.embed_dim],
376
+ self.in_proj_bias[self.embed_dim : self.embed_dim * 2],
377
+ self.in_proj_bias[self.embed_dim * 2 :],
378
+ )
379
+ q = (x @ q_weight.T + q_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
380
+ k = (x @ k_weight.T + k_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
381
+ v = (x @ v_weight.T + v_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
382
+
383
+ dk = q.size()[-1]
384
+ q = q / math.sqrt(dk)
385
+ attn = q @ k.transpose(-2, -1)
386
+ if attn_mask is not None:
387
+ attn += attn_mask
388
+ attn = attn.softmax(dim=-1)
389
+ x = torch.einsum("bhnm,bhmc->bhnmc", attn, v)
390
+ x = x.sum(axis=3).transpose(1, 2).reshape(B, N, C)
391
+ x = x @ self.out_proj.weight.T + self.out_proj.bias
392
+ return x
393
+
394
+ def forward_per_head(self, x, attn_mask=None):
395
+ B, N, C = x.shape
396
+ q_weight, k_weight, v_weight = self._split_qkv_weight()
397
+ q_bias, k_bias, v_bias = self._split_qkv_bias()
398
+ q = torch.einsum("bnc,hdc->bhnd", x, q_weight) + q_bias
399
+ k = torch.einsum("bnc,hdc->bhnd", x, k_weight) + k_bias
400
+ v = torch.einsum("bnc,hdc->bhnd", x, v_weight) + v_bias
401
+
402
+ dk = q.size()[-1]
403
+ q = q / math.sqrt(dk)
404
+ attn = q @ k.transpose(-2, -1)
405
+ if attn_mask is not None:
406
+ attn += attn_mask
407
+ attn = attn.softmax(dim=-1)
408
+ x = torch.einsum("bhnm,bhmc->bnmhc", attn, v)
409
+ x = torch.einsum(
410
+ "bnmhc,dhc->bnmhd",
411
+ x,
412
+ self.out_proj.weight.reshape(self.embed_dim, self.num_heads, self.head_dim),
413
+ )
414
+ x = x.sum(axis=[2, 3]) + self.out_proj.bias
415
+ return x
416
+
417
+ def _get_ov_circuit(self):
418
+ reshaped_o = self.out_proj.weight.reshape(
419
+ self.embed_dim, self.num_heads, self.head_dim
420
+ )
421
+ _, _, v_weight = self._split_qkv_weight()
422
+ _, _, v_bias = self._split_qkv_bias()
423
+ ov_circuit = torch.einsum("onh,nhi->oni", reshaped_o, v_weight)
424
+ ov_bias_circuit = torch.einsum("onh,bnxh->bnxo", reshaped_o, v_bias)
425
+ return ov_circuit, ov_bias_circuit
426
+
427
+ def forward_ov_circuit(self, x, attn_mask=None):
428
+ B, N, C = x.shape
429
+ q_weight, k_weight, _ = self._split_qkv_weight()
430
+ q_bias, k_bias, _ = self._split_qkv_bias()
431
+ q = torch.einsum("bnc,hdc->bhnd", x, q_weight) + q_bias
432
+ k = torch.einsum("bnc,hdc->bhnd", x, k_weight) + k_bias
433
+ ov, ov_bias = self._get_ov_circuit()
434
+ v = torch.einsum("bnc,dhc->bhnd", x, ov) + ov_bias
435
+
436
+ dk = q.size()[-1]
437
+ q = q / math.sqrt(dk)
438
+ attn = q @ k.transpose(-2, -1)
439
+ if attn_mask is not None:
440
+ attn += attn_mask
441
+ attn = attn.softmax(dim=-1)
442
+ x = torch.einsum("bhnm,bhmc->bnmhc", attn, v)
443
+ x = x.sum(axis=[2, 3]) + self.out_proj.bias
444
+ return x
445
+
446
+ def forward(self, x, attn_mask=None, method: Text = "ov_circuit"):
447
+ if method == "direct":
448
+ return self.forward_direct(x, attn_mask=attn_mask)
449
+ elif method == "qkv":
450
+ return self.forward_qkv(x, attn_mask=attn_mask)
451
+ elif method == "head":
452
+ return self.forward_per_head(x, attn_mask=attn_mask)
453
+ elif method == "ov_circuit":
454
+ return self.forward_ov_circuit(x, attn_mask=attn_mask)
455
+
456
+
457
+ class ResidualAttentionBlock(nn.Module):
458
+ def __init__(
459
+ self,
460
+ d_model: int,
461
+ n_head: int,
462
+ mlp_ratio: float = 4.0,
463
+ ls_init_value: float = None,
464
+ act_layer: Callable = nn.GELU,
465
+ norm_layer: Callable = LayerNorm,
466
+ layer_id: Optional[int] = None,
467
+ ):
468
+ super().__init__()
469
+ self.ln_1 = norm_layer(d_model)
470
+ self.attn = MultiheadAttention(d_model, n_head)
471
+ self.layer_id = layer_id
472
+
473
+ self.ls_1 = (
474
+ LayerScale(d_model, ls_init_value)
475
+ if ls_init_value is not None
476
+ else nn.Identity()
477
+ )
478
+
479
+ self.ln_2 = norm_layer(d_model)
480
+ self.mlp_width = int(d_model * mlp_ratio)
481
+ self.mlp = MLP(
482
+ d_model,
483
+ self.mlp_width,
484
+ act_layer=act_layer,
485
+ layer_id=layer_id,
486
+ )
487
+ self.ls_2 = (
488
+ LayerScale(d_model, ls_init_value)
489
+ if ls_init_value is not None
490
+ else nn.Identity()
491
+ )
492
+
493
+ def attention(
494
+ self,
495
+ q_x: torch.Tensor,
496
+ attn_mask: Optional[torch.Tensor] = None,
497
+ method: Text = "direct",
498
+ ):
499
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
500
+ return self.attn(q_x, attn_mask=attn_mask, method=method)
501
+
502
+ def forward(
503
+ self,
504
+ q_x: torch.Tensor,
505
+ attn_mask: Optional[torch.Tensor] = None,
506
+ attn_method: Text = "direct",
507
+ neuron_dict=None,
508
+ num_register_tokens=0
509
+ ):
510
+ after_ln1 = self.ln_1(q_x)
511
+ after_attn = self.attention(
512
+ q_x=after_ln1, attn_mask=attn_mask, method=attn_method
513
+ )
514
+ x = q_x + self.ls_1(after_attn)
515
+ after_ln2 = self.ln_2(x)
516
+ after_mlp = self.mlp(after_ln2, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens)
517
+ x = x + self.ls_2(after_mlp)
518
+ return x
519
+
520
+
521
+ class Transformer(nn.Module):
522
+ def __init__(
523
+ self,
524
+ width: int,
525
+ layers: int,
526
+ heads: int,
527
+ mlp_ratio: float = 4.0,
528
+ ls_init_value: float = None,
529
+ act_layer: Callable = nn.GELU,
530
+ norm_layer: Callable = LayerNorm,
531
+ ):
532
+ super().__init__()
533
+ self.width = width
534
+ self.layers = layers
535
+ self.grad_checkpointing = False
536
+
537
+ self.resblocks = nn.ModuleList(
538
+ [
539
+ ResidualAttentionBlock(
540
+ width,
541
+ heads,
542
+ mlp_ratio,
543
+ ls_init_value=ls_init_value,
544
+ act_layer=act_layer,
545
+ norm_layer=norm_layer,
546
+ layer_id=i,
547
+ )
548
+ for i in range(layers)
549
+ ]
550
+ )
551
+
552
+ def get_cast_dtype(self) -> torch.dtype:
553
+ if hasattr(self.resblocks[0].mlp.c_fc, "int8_original_dtype"):
554
+ return self.resblocks[0].mlp.c_fc.int8_original_dtype
555
+ return self.resblocks[0].mlp.c_fc.weight.dtype
556
+
557
+ def forward(
558
+ self,
559
+ x: torch.Tensor,
560
+ attn_mask: Optional[torch.Tensor] = None,
561
+ attn_method: Text = "direct",
562
+ neuron_dict=None,
563
+ num_register_tokens=0
564
+ ):
565
+ for r in self.resblocks:
566
+ if self.grad_checkpointing and not torch.jit.is_scripting():
567
+ raise ValueError("grad_checkpointing not implemented")
568
+ else:
569
+ x = r(
570
+ x,
571
+ attn_mask=attn_mask,
572
+ attn_method=attn_method,
573
+ neuron_dict=neuron_dict,
574
+ num_register_tokens=num_register_tokens
575
+ )
576
+ return x
577
+
578
+
579
+ class VisionTransformer(nn.Module):
580
+ output_tokens: torch.jit.Final[bool]
581
+
582
+ def __init__(
583
+ self,
584
+ image_size: int,
585
+ patch_size: int,
586
+ width: int,
587
+ layers: int,
588
+ heads: int,
589
+ mlp_ratio: float,
590
+ ls_init_value: float = None,
591
+ global_average_pool: bool = False,
592
+ attentional_pool: bool = False,
593
+ n_queries: int = 256,
594
+ attn_pooler_heads: int = 8,
595
+ output_dim: int = 512,
596
+ patch_dropout: float = 0.0,
597
+ input_patchnorm: bool = False,
598
+ act_layer: Callable = nn.GELU,
599
+ norm_layer: Callable = LayerNorm,
600
+ output_tokens: bool = False,
601
+ ):
602
+ super().__init__()
603
+ self.output_tokens = output_tokens
604
+ image_height, image_width = self.image_size = to_2tuple(image_size)
605
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
606
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
607
+ self.output_dim = output_dim
608
+
609
+ self.num_register_tokens = 0
610
+ self.neuron_dict = None
611
+
612
+ self.input_patchnorm = input_patchnorm
613
+
614
+ if input_patchnorm:
615
+ patch_input_dim = patch_height * patch_width * 3
616
+ self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
617
+ self.conv1 = nn.Linear(patch_input_dim, width)
618
+ else:
619
+ self.patchnorm_pre_ln = nn.Identity()
620
+ self.conv1 = nn.Conv2d(
621
+ in_channels=3,
622
+ out_channels=width,
623
+ kernel_size=patch_size,
624
+ stride=patch_size,
625
+ bias=False,
626
+ )
627
+
628
+ scale = width**-0.5
629
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
630
+ self.positional_embedding = nn.Parameter(
631
+ scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)
632
+ )
633
+
634
+ self.width = width
635
+ self.scale = scale
636
+ self.extra_token = self.scale * torch.randn(width)
637
+
638
+ self.patch_dropout = (
639
+ PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
640
+ )
641
+
642
+ self.ln_pre = norm_layer(width)
643
+ self.transformer = Transformer(
644
+ width,
645
+ layers,
646
+ heads,
647
+ mlp_ratio,
648
+ ls_init_value=ls_init_value,
649
+ act_layer=act_layer,
650
+ norm_layer=norm_layer,
651
+ )
652
+
653
+ self.global_average_pool = global_average_pool
654
+ if attentional_pool:
655
+ self.attn_pool = AttentionalPooler(
656
+ output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries
657
+ )
658
+ self.ln_post = norm_layer(output_dim)
659
+ self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
660
+ else:
661
+ self.attn_pool = None
662
+ self.ln_post = norm_layer(width)
663
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
664
+
665
+ @torch.jit.ignore
666
+ def set_grad_checkpointing(self, enable=True):
667
+ self.transformer.grad_checkpointing = enable
668
+
669
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
670
+ if self.global_average_pool:
671
+ return x.mean(dim=1), x
672
+ else:
673
+ return x[:, 0], x[:, 1:]
674
+
675
+ def forward(self, x: torch.Tensor, attn_method: Text = "direct", num_register_tokens = None, neuron_dict=None):
676
+ # to patches
677
+
678
+ if num_register_tokens is None and neuron_dict is None:
679
+ num_register_tokens = self.num_register_tokens
680
+ neuron_dict = self.neuron_dict
681
+
682
+ if self.input_patchnorm:
683
+ x = x.reshape(
684
+ x.shape[0],
685
+ x.shape[1],
686
+ self.grid_size[0],
687
+ self.patch_size[0],
688
+ self.grid_size[1],
689
+ self.patch_size[1],
690
+ )
691
+ x = x.permute(0, 2, 4, 1, 3, 5)
692
+ x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
693
+ x = self.patchnorm_pre_ln(x)
694
+ x = self.conv1(x)
695
+ else:
696
+ x = self.conv1(x)
697
+ x = x.reshape(x.shape[0], x.shape[1], -1)
698
+ x = x.permute(0, 2, 1)
699
+
700
+ # class embeddings and positional embeddings
701
+ x = torch.cat([
702
+ self.class_embedding.to(x.dtype)
703
+ + torch.zeros(
704
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
705
+ ),
706
+ x,
707
+ ],
708
+ dim=1,
709
+ )
710
+ x = x + self.positional_embedding.to(x.dtype)
711
+
712
+ extra_token_embeddings = []
713
+ total_patches = x.shape[1] - 1
714
+ for i in range(num_register_tokens):
715
+ extra_token_embeddings.append(
716
+ torch.zeros(
717
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
718
+ ),
719
+ )
720
+
721
+ # Add extra tokens
722
+ if num_register_tokens > 0:
723
+ x = torch.cat([x, *extra_token_embeddings], dim=1)
724
+
725
+ x = self.patch_dropout(x)
726
+ x = self.ln_pre(x)
727
+
728
+ x = self.transformer(x, attn_mask=None, attn_method=attn_method, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens)
729
+
730
+ if self.attn_pool is not None:
731
+ x = self.attn_pool(x)
732
+ x = self.ln_post(x)
733
+ pooled, tokens = self._global_pool(x)
734
+ else:
735
+ pooled, tokens = self._global_pool(x)
736
+ pooled = self.ln_post(pooled)
737
+
738
+ if self.proj is not None:
739
+ pooled = pooled @ self.proj
740
+
741
+ if self.output_tokens:
742
+ return pooled, tokens
743
+
744
+ return pooled
745
+
746
+
747
+ class TextTransformer(nn.Module):
748
+ output_tokens: torch.jit.Final[bool]
749
+
750
+ def __init__(
751
+ self,
752
+ context_length: int = 77,
753
+ vocab_size: int = 49408,
754
+ width: int = 512,
755
+ heads: int = 8,
756
+ layers: int = 12,
757
+ ls_init_value: float = None,
758
+ output_dim: int = 512,
759
+ act_layer: Callable = nn.GELU,
760
+ norm_layer: Callable = LayerNorm,
761
+ embed_cls: bool = False,
762
+ pad_id: int = 0,
763
+ output_tokens: bool = False,
764
+ ):
765
+ super().__init__()
766
+ self.output_tokens = output_tokens
767
+ self.num_pos = self.context_length = context_length
768
+ self.vocab_size = vocab_size
769
+ self.width = width
770
+ self.output_dim = output_dim
771
+ self.heads = heads
772
+ self.pad_id = pad_id
773
+
774
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
775
+
776
+ if embed_cls:
777
+ self.cls_emb = nn.Parameter(torch.empty(width))
778
+ self.num_pos += 1
779
+ else:
780
+ self.cls_emb = None
781
+
782
+ self.token_embedding = nn.Embedding(vocab_size, width)
783
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
784
+ self.transformer = Transformer(
785
+ width=width,
786
+ layers=layers,
787
+ heads=heads,
788
+ ls_init_value=ls_init_value,
789
+ act_layer=act_layer,
790
+ norm_layer=norm_layer,
791
+ )
792
+ self.ln_final = norm_layer(width)
793
+
794
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
795
+
796
+ self.init_parameters()
797
+
798
+ def init_parameters(self):
799
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
800
+ nn.init.normal_(self.positional_embedding, std=0.01)
801
+ if self.cls_emb is not None:
802
+ nn.init.normal_(self.cls_emb, std=0.01)
803
+
804
+ proj_std = (self.transformer.width**-0.5) * (
805
+ (2 * self.transformer.layers) ** -0.5
806
+ )
807
+ attn_std = self.transformer.width**-0.5
808
+ fc_std = (2 * self.transformer.width) ** -0.5
809
+ for block in self.transformer.resblocks:
810
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
811
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
812
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
813
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
814
+
815
+ if self.text_projection is not None:
816
+ nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
817
+
818
+ @torch.jit.ignore
819
+ def set_grad_checkpointing(self, enable=True):
820
+ self.transformer.grad_checkpointing = enable
821
+
822
+ def build_attention_mask(self):
823
+ mask = torch.empty(self.num_pos, self.num_pos)
824
+ mask.fill_(float("-inf"))
825
+ mask.triu_(1)
826
+ return mask
827
+
828
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
829
+ cls_mask = (text != self.pad_id).unsqueeze(1)
830
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
831
+ additive_mask = torch.empty(
832
+ cls_mask.shape, dtype=cast_dtype, device=cls_mask.device
833
+ )
834
+ additive_mask.fill_(0)
835
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
836
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
837
+ return additive_mask
838
+
839
+ def _repeat(self, t, N: int):
840
+ return t.reshape(1, 1, -1).repeat(N, 1, 1)
841
+
842
+ def forward(self, text, attn_method: Text = "direct"):
843
+ cast_dtype = self.transformer.get_cast_dtype()
844
+ seq_len = text.shape[1]
845
+
846
+ x = self.token_embedding(text).to(cast_dtype)
847
+ attn_mask = self.attn_mask
848
+ if self.cls_emb is not None:
849
+ seq_len += 1
850
+ x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
851
+ cls_mask = self.build_cls_mask(text, cast_dtype)
852
+ attn_mask = (
853
+ attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
854
+ )
855
+
856
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
857
+ x = self.transformer(x, attn_mask=attn_mask, attn_method=attn_method)
858
+
859
+ if self.cls_emb is not None:
860
+ pooled, tokens = x[:, -1], x[:, :-1]
861
+ pooled = self.ln_final(pooled)
862
+ else:
863
+ x = self.ln_final(x)
864
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
865
+
866
+ if self.text_projection is not None:
867
+ pooled = pooled @ self.text_projection
868
+
869
+ if self.output_tokens:
870
+ return pooled, tokens
871
+
872
+ return pooled
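
For orientation (not part of the commit): `VisionTransformer.forward` appends `num_register_tokens` zero-initialized tokens after the patch tokens, and `MLP.forward` copies the per-image maximum activation of the neurons listed in `neuron_dict` into those register slots while zeroing those neurons on the patch tokens (the CLS token keeps its original activation). A minimal sketch of driving this directly; the ViT-L/14 shape matches the base model, while the layer/neuron indices below are placeholders (the checkpoint ships its own `neuron_dict`):

```python
# Hypothetical sketch with randomly initialized weights; neuron_dict values are placeholders.
import torch
from transformer import VisionTransformer

vit = VisionTransformer(
    image_size=224, patch_size=14, width=1024, layers=24, heads=16,
    mlp_ratio=4.0, output_dim=768,
)
vit.num_register_tokens = 4            # test-time registers appended after the patch tokens
vit.neuron_dict = {20: [1027, 2038]}   # {layer_id: [MLP neuron indices]} -- placeholder values

x = torch.randn(1, 3, 224, 224)
pooled = vit(x)                        # registers absorb the high-norm MLP activations
print(pooled.shape)                    # torch.Size([1, 768])
```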
utils.py ADDED
@@ -0,0 +1,34 @@
import time
import torch
from datetime import datetime
import os

class SaveOcassionally:
    def __init__(self, out, every_sec = None, every_count = None):
        assert every_sec != None or every_count != None

        self.out = out
        self.curr_time = time.time()
        self.every_sec = every_sec
        self.cnt = 0
        self.every_count = every_count
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if "TIMESTAMP" in self.out:
            self.out = self.out.replace("TIMESTAMP", self.timestamp)
            print(f"Replacing TIMESTAMP with {self.timestamp}")

        # Ensure the directory exists
        out_dir = os.path.abspath(os.path.dirname(self.out))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    def save(self, obj):
        self.cnt += 1
        if self.every_sec != None and time.time() - self.curr_time > self.every_sec:
            torch.save(obj, self.out)
        elif self.every_count != None and self.cnt % self.every_count == 0:
            torch.save(obj, self.out)

    def force_save(self, obj):
        torch.save(obj, self.out)
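
A small usage sketch of `SaveOcassionally` (not part of the commit; the output path is illustrative):

```python
# Hypothetical usage sketch -- saves a running results dict every 100 calls.
from utils import SaveOcassionally

saver = SaveOcassionally("outputs/TIMESTAMP/results.pt", every_count=100)
results = {}
for step in range(1000):
    results[step] = step * step
    saver.save(results)     # only writes on every 100th call
saver.force_save(results)   # always writes
```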
utils/utils.py ADDED
@@ -0,0 +1,34 @@
import time
import torch
from datetime import datetime
import os

class SaveOcassionally:
    def __init__(self, out, every_sec = None, every_count = None):
        assert every_sec != None or every_count != None

        self.out = out
        self.curr_time = time.time()
        self.every_sec = every_sec
        self.cnt = 0
        self.every_count = every_count
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if "TIMESTAMP" in self.out:
            self.out = self.out.replace("TIMESTAMP", self.timestamp)
            print(f"Replacing TIMESTAMP with {self.timestamp}")

        # Ensure the directory exists
        out_dir = os.path.abspath(os.path.dirname(self.out))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    def save(self, obj):
        self.cnt += 1
        if self.every_sec != None and time.time() - self.curr_time > self.every_sec:
            torch.save(obj, self.out)
        elif self.every_count != None and self.cnt % self.every_count == 0:
            torch.save(obj, self.out)

    def force_save(self, obj):
        torch.save(obj, self.out)
vitl14_attention.png ADDED

Git LFS Details

  • SHA256: 004298f80c7128c800b33c1f523eec8dc325350d44dab5de2067c160adadd1fe
  • Pointer size: 131 Bytes
  • Size of remote file: 287 kB
vitl14_patchnorms.png ADDED

Git LFS Details

  • SHA256: 4a14def773bd1264055007b2bbb6860452071a53f66100b51466ed77498ce1a6
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
zeroshot_classifier.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7dff47ac37ed4b67771bf6cf651a55dcf95d22eddc91acce2f54638ec82c6783
size 1537240