JH-C-k committed on
Commit
ec54512
·
verified ·
1 Parent(s): 36a9b01

Add files using upload-large-folder tool

__pycache__/model.cpython-310.pyc CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
 
__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/__pycache__/transformer.cpython-310.pyc and b/__pycache__/transformer.cpython-310.pyc differ
 
config.json CHANGED
@@ -22,7 +22,123 @@
22
  "num_hidden_layers": 12,
23
  "max_position_embeddings": 77
24
  },
25
- "neuron_dict": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  "projection_dim": 768,
27
  "torch_dtype": "float32",
28
  "transformers_version": "4.21.0"
 
22
  "num_hidden_layers": 12,
23
  "max_position_embeddings": 77
24
  },
25
+ "num_register_tokens": 1,
26
+ "neuron_dict": {
27
+ "9": [
28
+ 815,
29
+ 4078,
30
+ 3618,
31
+ 2693,
32
+ 3973,
33
+ 1744,
34
+ 1983,
35
+ 1157,
36
+ 1309,
37
+ 1335,
38
+ 2607,
39
+ 2396,
40
+ 3049,
41
+ 1610,
42
+ 2621,
43
+ 2867,
44
+ 2012,
45
+ 1924,
46
+ 2394,
47
+ 3097,
48
+ 3125,
49
+ 3959,
50
+ 3210,
51
+ 2855,
52
+ 3609,
53
+ 526,
54
+ 3362,
55
+ 3395,
56
+ 2626,
57
+ 503,
58
+ 2941,
59
+ 3696,
60
+ 1823,
61
+ 2000,
62
+ 129,
63
+ 3667,
64
+ 1372,
65
+ 147,
66
+ 1150,
67
+ 852,
68
+ 3222
69
+ ],
70
+ "8": [
71
+ 745,
72
+ 3249,
73
+ 2585,
74
+ 1537,
75
+ 200,
76
+ 1603,
77
+ 1851,
78
+ 3523,
79
+ 3697,
80
+ 3137,
81
+ 2563,
82
+ 2293,
83
+ 730,
84
+ 906,
85
+ 1528,
86
+ 3348,
87
+ 2438,
88
+ 1564,
89
+ 1540,
90
+ 3238,
91
+ 3606
92
+ ],
93
+ "10": [
94
+ 357,
95
+ 1654,
96
+ 3940,
97
+ 2319,
98
+ 2560,
99
+ 2559,
100
+ 4009,
101
+ 3029,
102
+ 951,
103
+ 1903,
104
+ 738,
105
+ 1602,
106
+ 1807,
107
+ 2018,
108
+ 1281,
109
+ 267,
110
+ 3539,
111
+ 1015,
112
+ 496,
113
+ 693,
114
+ 2278,
115
+ 7,
116
+ 856,
117
+ 2785,
118
+ 2690,
119
+ 1367
120
+ ],
121
+ "7": [
122
+ 3228,
123
+ 2550,
124
+ 2977,
125
+ 3716,
126
+ 2467
127
+ ],
128
+ "0": [
129
+ 2890,
130
+ 1779,
131
+ 3761
132
+ ],
133
+ "6": [
134
+ 1042,
135
+ 2315,
136
+ 1674
137
+ ],
138
+ "3": [
139
+ 410
140
+ ]
141
+ },
142
  "projection_dim": 768,
143
  "torch_dtype": "float32",
144
  "transformers_version": "4.21.0"
model.py CHANGED
@@ -270,14 +270,23 @@ class CLIP(nn.Module):
270
  self.visual.set_grad_checkpointing(enable)
271
  self.transformer.grad_checkpointing = enable
272
 
273
- def encode_image(self, image, normalize: bool = False, attn_method: Text = 'direct', num_register_tokens = None, neuron_dict=None):
274
  if num_register_tokens is None and neuron_dict is None:
275
  num_register_tokens = self.num_register_tokens
276
  neuron_dict = self.neuron_dict
277
-
278
 
279
- features = self.visual(image, attn_method=attn_method, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict)
280
- return F.normalize(features, dim=-1) if normalize else features
281
 
282
  def encode_text(self, text, normalize: bool = False):
283
  cast_dtype = self.transformer.get_cast_dtype()
 
270
  self.visual.set_grad_checkpointing(enable)
271
  self.transformer.grad_checkpointing = enable
272
 
273
+ def encode_image(self, image, normalize: bool = False, attn_method: Text = 'direct', num_register_tokens = None, neuron_dict=None, get_hidden_states=False):
274
  if num_register_tokens is None and neuron_dict is None:
275
  num_register_tokens = self.num_register_tokens
276
  neuron_dict = self.neuron_dict
 
277
 
278
+ if get_hidden_states:
279
+ ret = self.visual(image, attn_method=attn_method, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict, get_hidden_states=get_hidden_states)
280
+ # warning: only the global cls token is normalized
281
+ return {
282
+ "pooled": F.normalize(ret["pooled"], dim=-1) if normalize else ret["pooled"],
283
+ "tokens": ret["tokens"],
284
+ "hidden_states": ret["hidden_states"]
285
+ }
286
+ else:
287
+ features = self.visual(image, attn_method=attn_method, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict)
288
+ return F.normalize(features, dim=-1) if normalize else features
289
+
290
 
291
  def encode_text(self, text, normalize: bool = False):
292
  cast_dtype = self.transformer.get_cast_dtype()
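
With get_hidden_states=True, encode_image now returns a dict rather than a bare feature tensor, and (as the in-line warning notes) only the pooled global CLS feature is normalized. A minimal usage sketch, assuming model is the CLIP instance defined in model.py and image is an already preprocessed [B, 3, H, W] tensor:

import torch

with torch.no_grad():
    feats = model.encode_image(image, normalize=True)                        # tensor, unchanged behaviour
    out = model.encode_image(image, normalize=True, get_hidden_states=True)  # dict

print(sorted(out.keys()))         # ['hidden_states', 'pooled', 'tokens']
print(out["pooled"].shape)        # normalized global CLS feature
print(len(out["hidden_states"]))  # input embedding plus one entry per transformer block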
model_sanity_check.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  },
14
  {
15
  "cell_type": "code",
16
- "execution_count": 2,
17
  "id": "e7cec94e",
18
  "metadata": {},
19
  "outputs": [],
@@ -65,7 +65,7 @@
65
  },
66
  {
67
  "cell_type": "code",
68
- "execution_count": null,
69
  "id": "b4c7a750",
70
  "metadata": {},
71
  "outputs": [
@@ -77,6 +77,7 @@
77
  "โœ“ Added '/workspace/code/clipL336_TTR' to Python path.\n",
78
  "โœ“ Successfully imported 'model' from '/workspace/code/clipL336_TTR'\n",
79
  "Building vision tower with config: CLIPVisionCfg(layers=24, width=1024, head_width=64, mlp_ratio=4.0, patch_size=14, image_size=336, ls_init_value=None, patch_dropout=0.0, input_patchnorm=False, global_average_pool=False, attentional_pool=False, n_queries=256, attn_pooler_heads=8, output_tokens=False, timm_model_name=None, timm_model_pretrained=False, timm_pool='avg', timm_proj='linear', timm_proj_bias=False, timm_drop=0.0, timm_drop_path=None)\n",
 
80
  "โœ“ Added '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR' to Python path.\n",
81
  "โœ“ Successfully imported 'tokenizer' from '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR'\n",
82
  "Custom CLIP model loaded successfully!\n"
@@ -104,7 +105,7 @@
104
  },
105
  {
106
  "cell_type": "code",
107
- "execution_count": 4,
108
  "id": "ed3cbfdc",
109
  "metadata": {},
110
  "outputs": [
@@ -112,7 +113,7 @@
112
  "name": "stderr",
113
  "output_type": "stream",
114
  "text": [
115
- "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 1000/1000 [00:23<00:00, 41.71it/s]"
116
  ]
117
  },
118
  {
@@ -133,36 +134,29 @@
133
  "source": [
134
  "# langauge head\n",
135
  "### zeroshot head construction (text encoding) ###\n",
136
- "with torch.no_grad():\n",
137
- " zeroshot_weight = []\n",
138
- " for classname in tqdm(IMAGENET_CLASSNAMES):\n",
139
- " texts = [template(classname) for template in OPENAI_IMAGENET_TEMPLATES]\n",
140
- " text_inputs = preprocessor(text=texts, return_tensors=\"pt\", padding=\"max_length\").to(device)\n",
141
- " # text_inputs = model.tokenize(texts).to(device)\n",
142
- " # text_features = model.encode_text(text_inputs.input_ids)\n",
143
- " text_features = model_clip.get_text_features(**text_inputs)\n",
144
- " text_feature = F.normalize(text_features, dim=-1).mean(dim=0)\n",
145
- " # text_feature = text_features.mean(dim=0)\n",
146
- " text_feature = text_feature / text_feature.norm()\n",
147
- " zeroshot_weight.append(text_feature)\n",
148
- " \n",
149
- " text_features = torch.stack(zeroshot_weight, dim=1).to(device)\n",
150
- "print(\"Built text features:\", text_features.shape)"
151
- ]
152
- },
153
- {
154
- "cell_type": "code",
155
- "execution_count": 10,
156
- "id": "e1bd37d1",
157
- "metadata": {},
158
- "outputs": [],
159
- "source": [
160
- "torch.save(text_features, \"./zeroshot_classifier.pt\")"
161
  ]
162
  },
163
  {
164
  "cell_type": "code",
165
- "execution_count": 5,
166
  "id": "dbfeaedf",
167
  "metadata": {},
168
  "outputs": [],
@@ -173,7 +167,7 @@
173
  },
174
  {
175
  "cell_type": "code",
176
- "execution_count": 6,
177
  "id": "b0000195",
178
  "metadata": {},
179
  "outputs": [],
@@ -204,44 +198,170 @@
204
  " return top1 / n * 100, top5 / n * 100\n"
205
  ]
206
  },
207
  {
208
  "cell_type": "code",
209
  "execution_count": 8,
210
- "id": "8795b394",
211
  "metadata": {},
212
  "outputs": [
213
  {
214
- "name": "stderr",
215
- "output_type": "stream",
216
- "text": [
217
- "Evaluating: 0%| | 0/391 [00:00<?, ?batch/s]"
218
- ]
219
- },
220
  {
221
- "name": "stderr",
222
- "output_type": "stream",
223
- "text": [
224
- "Evaluating: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 391/391 [10:38<00:00, 1.63s/batch, samples=5e+4, top1=74.9, top5=94.4] "
225
- ]
226
- },
227
  {
228
- "name": "stdout",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  "output_type": "stream",
230
  "text": [
231
- "Baseline (Topโ€‘1 / Topโ€‘5) on 50,000 imgs: 74.87% / 94.37%\n"
232
  ]
233
  },
234
  {
235
- "name": "stderr",
236
- "output_type": "stream",
237
- "text": [
238
- "\n"
 
 
 
 
 
239
  ]
240
  }
241
  ],
242
  "source": [
243
  "\n",
244
  "### baseline evaluator ###\n",
 
 
 
 
245
  "\n",
246
  "BASELINE_SAMPLES = 50000 # set to None for full 50โ€ฏk\n",
247
  "acc1, acc5 = evaluate(model, eval_loader, text_features, max_samples=BASELINE_SAMPLES)\n",
 
13
  },
14
  {
15
  "cell_type": "code",
16
+ "execution_count": 1,
17
  "id": "e7cec94e",
18
  "metadata": {},
19
  "outputs": [],
 
65
  },
66
  {
67
  "cell_type": "code",
68
+ "execution_count": 2,
69
  "id": "b4c7a750",
70
  "metadata": {},
71
  "outputs": [
 
77
  "โœ“ Added '/workspace/code/clipL336_TTR' to Python path.\n",
78
  "โœ“ Successfully imported 'model' from '/workspace/code/clipL336_TTR'\n",
79
  "Building vision tower with config: CLIPVisionCfg(layers=24, width=1024, head_width=64, mlp_ratio=4.0, patch_size=14, image_size=336, ls_init_value=None, patch_dropout=0.0, input_patchnorm=False, global_average_pool=False, attentional_pool=False, n_queries=256, attn_pooler_heads=8, output_tokens=False, timm_model_name=None, timm_model_pretrained=False, timm_pool='avg', timm_proj='linear', timm_proj_bias=False, timm_drop=0.0, timm_drop_path=None)\n",
80
+ "Currently text tower is removed, using only image encoder for feature extraction\n",
81
  "โœ“ Added '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR' to Python path.\n",
82
  "โœ“ Successfully imported 'tokenizer' from '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR'\n",
83
  "Custom CLIP model loaded successfully!\n"
 
105
  },
106
  {
107
  "cell_type": "code",
108
+ "execution_count": 3,
109
  "id": "ed3cbfdc",
110
  "metadata": {},
111
  "outputs": [
 
113
  "name": "stderr",
114
  "output_type": "stream",
115
  "text": [
116
+ "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 1000/1000 [00:23<00:00, 41.78it/s]"
117
  ]
118
  },
119
  {
 
134
  "source": [
135
  "# langauge head\n",
136
  "### zeroshot head construction (text encoding) ###\n",
137
+ "construction_language_cls_head = True\n",
138
+ "\n",
139
+ "if construction_language_cls_head:\n",
140
+ " with torch.no_grad():\n",
141
+ " zeroshot_weight = []\n",
142
+ " for classname in tqdm(IMAGENET_CLASSNAMES):\n",
143
+ " texts = [template(classname) for template in OPENAI_IMAGENET_TEMPLATES]\n",
144
+ " text_inputs = preprocessor(text=texts, return_tensors=\"pt\", padding=\"max_length\").to(device)\n",
145
+ " # text_inputs = model.tokenize(texts).to(device)\n",
146
+ " # text_features = model.encode_text(text_inputs.input_ids)\n",
147
+ " text_features = model_clip.get_text_features(**text_inputs)\n",
148
+ " text_feature = F.normalize(text_features, dim=-1).mean(dim=0)\n",
149
+ " # text_feature = text_features.mean(dim=0)\n",
150
+ " text_feature = text_feature / text_feature.norm()\n",
151
+ " zeroshot_weight.append(text_feature)\n",
152
+ " \n",
153
+ " text_features = torch.stack(zeroshot_weight, dim=1).to(device)\n",
154
+ " print(\"Built text features:\", text_features.shape)"
 
 
 
 
 
 
 
155
  ]
156
  },
157
  {
158
  "cell_type": "code",
159
+ "execution_count": 4,
160
  "id": "dbfeaedf",
161
  "metadata": {},
162
  "outputs": [],
 
167
  },
168
  {
169
  "cell_type": "code",
170
+ "execution_count": 5,
171
  "id": "b0000195",
172
  "metadata": {},
173
  "outputs": [],
 
198
  " return top1 / n * 100, top5 / n * 100\n"
199
  ]
200
  },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 6,
204
+ "id": "5806f422",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "model = model.half().to(device)"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 7,
214
+ "id": "21372f58",
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "sample_image = imagenet_dataset[0][0].unsqueeze(0).half().to(device)\n",
219
+ "result = model.encode_image(sample_image, get_hidden_states=True) # test\n",
220
+ "# ์ด๊ฑฐ ๊ทธ๋Œ€๋กœ ์ด์ œ foward"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": 16,
226
+ "id": "741d9cac",
227
+ "metadata": {},
228
+ "outputs": [],
229
+ "source": [
230
+ "vision_forward_result = model_clip.vision_model(pixel_values=sample_image, output_hidden_states=True)"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 20,
236
+ "id": "c3dd134e",
237
+ "metadata": {},
238
+ "outputs": [
239
+ {
240
+ "data": {
241
+ "text/plain": [
242
+ "torch.Size([1, 577, 1024])"
243
+ ]
244
+ },
245
+ "execution_count": 20,
246
+ "metadata": {},
247
+ "output_type": "execute_result"
248
+ }
249
+ ],
250
+ "source": [
251
+ "vision_forward_result.hidden_states[0].shape\n",
252
+ "# ์—ฌ๊ธฐ์„œ๋„ 25์ด๋‹ค."
253
+ ]
254
+ },
255
  {
256
  "cell_type": "code",
257
  "execution_count": 8,
258
+ "id": "04d16694",
259
  "metadata": {},
260
  "outputs": [
261
  {
262
+ "data": {
263
+ "text/plain": [
264
+ "dict_keys(['pooled', 'tokens', 'hidden_states'])"
265
+ ]
266
+ },
267
+ "execution_count": 8,
268
+ "metadata": {},
269
+ "output_type": "execute_result"
270
+ }
271
+ ],
272
+ "source": [
273
+ "result.keys()\n",
274
+ "# ํ•˜๋‚˜๋งŒ ๋” ํ™•์ธํ•˜๊ธฐ CLS token์„ ๋ถ™์ด๋Š” ๊ฐ€ ์•„๋‹Œ๊ฐ€?"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": 9,
280
+ "id": "36c68c45",
281
+ "metadata": {},
282
+ "outputs": [
283
  {
284
+ "data": {
285
+ "text/plain": [
286
+ "torch.Size([1, 578, 1024])"
287
+ ]
288
+ },
289
+ "execution_count": 9,
290
+ "metadata": {},
291
+ "output_type": "execute_result"
292
+ }
293
+ ],
294
+ "source": [
295
+ "# ํ˜„ ์ƒํ™ฉ์„ ๋ณด๋ฉด register token์€ ์•ˆ ๋“ค์–ด๊ฐ€ ๊ฐ„ ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค.\n",
296
+ "result[\"hidden_states\"][0].shape"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": null,
302
+ "id": "d9f06ec5",
303
+ "metadata": {},
304
+ "outputs": [
305
  {
306
+ "data": {
307
+ "text/plain": [
308
+ "1"
309
+ ]
310
+ },
311
+ "execution_count": 12,
312
+ "metadata": {},
313
+ "output_type": "execute_result"
314
+ }
315
+ ],
316
+ "source": [
317
+ "model.num_register_tokens\n",
318
+ "# ok, hidden state ๋„ฃ์–ด์ค„ ๋•Œ, layer_idx, ๊ทธ๋ฆฌ๊ณ  num_register ์ž˜ ์ธ์ง€ํ•ด์„œ, parsing์„ ํ•ด์ฃผ๋„๋ก ํ•ด์•ผ ๊ฒ ๋‹ค.\n",
319
+ "# model.neuron_dict"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": null,
325
+ "id": "4504ccd6",
326
+ "metadata": {},
327
+ "outputs": [],
328
+ "source": [
329
+ "raise StopIteration()"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": 10,
335
+ "id": "8795b394",
336
+ "metadata": {},
337
+ "outputs": [
338
+ {
339
+ "name": "stderr",
340
  "output_type": "stream",
341
  "text": [
342
+ "Evaluating: 2%|โ– | 6/391 [00:13<14:35, 2.27s/batch, samples=768, top1=89.1, top5=98.3]\n"
343
  ]
344
  },
345
  {
346
+ "ename": "KeyboardInterrupt",
347
+ "evalue": "",
348
+ "output_type": "error",
349
+ "traceback": [
350
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
351
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
352
+ "Cell \u001b[0;32mIn[10], line 8\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m### baseline evaluator ###\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m### ์ด๊ฑฐ๋Š” ์ง€๊ธˆ ๋‹น์žฅ์€ ๋ชป ์จ๋จน๋Š”๋‹ค... ๋ฏธ์นœ ๋„ˆ๋ฌด ๋А๋ฆฌ๋‹ค ์–ด๋””์„œ ๋ฌธ์ œ์ง€ ###\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# ์”จ๋ฐœ ์ด๋ฒˆ์— ๋ญ์ง€\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# architecture define์ด ์–ด๋”˜๊ฐ€์—์„œ ์†์ƒ ๋œ ๊ฒƒ์œผ๋กœ ๋ณด์ธ๋‹ค\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# ์„ฑ๋Šฅ reproduce...\u001b[39;00m\n\u001b[1;32m 7\u001b[0m BASELINE_SAMPLES \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m50000\u001b[39m \u001b[38;5;66;03m# set to None for full 50โ€ฏk\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m acc1, acc5 \u001b[38;5;241m=\u001b[39m \u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext_features\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mBASELINE_SAMPLES\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBaseline (Topโ€‘1 / Topโ€‘5) on \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mBASELINE_SAMPLES\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mor\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mlen\u001b[39m(imagenet_dataset)\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m,\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m imgs: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00macc1\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m% / \u001b[39m\u001b[38;5;132;01m{\u001b[39;00macc5\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
353
+ "Cell \u001b[0;32mIn[5], line 20\u001b[0m, in \u001b[0;36mevaluate\u001b[0;34m(model, loader, text_feats, max_samples)\u001b[0m\n\u001b[1;32m 18\u001b[0m logits \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mlogit_scale\u001b[38;5;241m.\u001b[39mexp() \u001b[38;5;241m*\u001b[39m feats \u001b[38;5;241m@\u001b[39m text_feats \n\u001b[1;32m 19\u001b[0m _, pred \u001b[38;5;241m=\u001b[39m logits\u001b[38;5;241m.\u001b[39mtopk(\u001b[38;5;241m5\u001b[39m, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 20\u001b[0m top1 \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msum\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m top5 \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (pred \u001b[38;5;241m==\u001b[39m labels\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m))\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 22\u001b[0m n \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m)\n",
354
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
355
  ]
356
  }
357
  ],
358
  "source": [
359
  "\n",
360
  "### baseline evaluator ###\n",
361
+ "### ์ด๊ฑฐ๋Š” ์ง€๊ธˆ ๋‹น์žฅ์€ ๋ชป ์จ๋จน๋Š”๋‹ค... ๋ฏธ์นœ ๋„ˆ๋ฌด ๋А๋ฆฌ๋‹ค ์–ด๋””์„œ ๋ฌธ์ œ์ง€ ###\n",
362
+ "# ์”จ๋ฐœ ์ด๋ฒˆ์— ๋ญ์ง€\n",
363
+ "# architecture define์ด ์–ด๋”˜๊ฐ€์—์„œ ์†์ƒ ๋œ ๊ฒƒ์œผ๋กœ ๋ณด์ธ๋‹ค\n",
364
+ "# ์„ฑ๋Šฅ reproduce...\n",
365
  "\n",
366
  "BASELINE_SAMPLES = 50000 # set to None for full 50โ€ฏk\n",
367
  "acc1, acc5 = evaluate(model, eval_loader, text_features, max_samples=BASELINE_SAMPLES)\n",
transformer.py CHANGED
@@ -560,8 +560,12 @@ class Transformer(nn.Module):
560
  attn_mask: Optional[torch.Tensor] = None,
561
  attn_method: Text = "direct",
562
  neuron_dict=None,
563
- num_register_tokens=0
564
- ):
565
  for r in self.resblocks:
566
  if self.grad_checkpointing and not torch.jit.is_scripting():
567
  raise ValueError("grad_checkpointing not implemented")
@@ -573,7 +577,16 @@ class Transformer(nn.Module):
573
  neuron_dict=neuron_dict,
574
  num_register_tokens=num_register_tokens
575
  )
576
- return x
 
 
 
 
 
 
 
 
 
577
 
578
 
579
  class VisionTransformer(nn.Module):
@@ -672,7 +685,8 @@ class VisionTransformer(nn.Module):
672
  else:
673
  return x[:, 0], x[:, 1:]
674
 
675
- def forward(self, x: torch.Tensor, attn_method: Text = "direct", num_register_tokens = None, neuron_dict=None):
 
676
  # to patches
677
 
678
  if num_register_tokens is None and neuron_dict is None:
@@ -725,7 +739,14 @@ class VisionTransformer(nn.Module):
725
  x = self.patch_dropout(x)
726
  x = self.ln_pre(x)
727
 
728
- x = self.transformer(x, attn_mask=None, attn_method=attn_method, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens)
729
 
730
  if self.attn_pool is not None:
731
  x = self.attn_pool(x)
@@ -740,6 +761,13 @@ class VisionTransformer(nn.Module):
740
 
741
  if self.output_tokens:
742
  return pooled, tokens
743
 
744
  return pooled
745
 
 
560
  attn_mask: Optional[torch.Tensor] = None,
561
  attn_method: Text = "direct",
562
  neuron_dict=None,
563
+ num_register_tokens=0,
564
+ get_hidden_states: bool = False,
565
+ ):
566
+ if "hidden_states" not in locals():
567
+ hidden_states = []
568
+ hidden_states.append(x) # store the input embedding
569
  for r in self.resblocks:
570
  if self.grad_checkpointing and not torch.jit.is_scripting():
571
  raise ValueError("grad_checkpointing not implemented")
 
577
  neuron_dict=neuron_dict,
578
  num_register_tokens=num_register_tokens
579
  )
580
+ if get_hidden_states:
581
+ hidden_states.append(x)
582
+
583
+ if get_hidden_states:
584
+ return {
585
+ "hidden_states": hidden_states,
586
+ "last_hidden_state": x,
587
+ }
588
+ else:
589
+ return x
590
 
591
 
592
  class VisionTransformer(nn.Module):
 
685
  else:
686
  return x[:, 0], x[:, 1:]
687
 
688
+ # from here
689
+ def forward(self, x: torch.Tensor, attn_method: Text = "direct", num_register_tokens = None, neuron_dict=None, get_hidden_states:bool=False):
690
  # to patches
691
 
692
  if num_register_tokens is None and neuron_dict is None:
 
739
  x = self.patch_dropout(x)
740
  x = self.ln_pre(x)
741
 
742
+ # what goes in here is [B, 1 + 576 + num_register_tokens, C]
743
+ if get_hidden_states:
744
+ ret = self.transformer(x, attn_mask=None, attn_method=attn_method, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens,get_hidden_states=get_hidden_states)
745
+
746
+ hidden_states = ret["hidden_states"]
747
+ x = ret["last_hidden_state"]
748
+ else:
749
+ x = self.transformer(x, attn_mask=None, attn_method=attn_method, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens,get_hidden_states=get_hidden_states)
750
 
751
  if self.attn_pool is not None:
752
  x = self.attn_pool(x)
 
761
 
762
  if self.output_tokens:
763
  return pooled, tokens
764
+
765
+ if get_hidden_states:
766
+ return {
767
+ "pooled": pooled, # GLOBAL CLS
768
+ "tokens": tokens, # ALL TOKENS
769
+ "hidden_states": hidden_states # layer-wise hidden states
770
+ }
771
 
772
  return pooled
773
 
zeroshot_classifier.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7dff47ac37ed4b67771bf6cf651a55dcf95d22eddc91acce2f54638ec82c6783
3
- size 1537240
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:629c9c73b717ffa38a56f57b20ebe4fd5470cc03d730f7919c2bacf2c388f560
3
+ size 124120
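
zeroshot_classifier.pt is the text-side head the sanity-check notebook builds and saves with torch.save(text_features, "./zeroshot_classifier.pt"): per-class, L2-normalized text embeddings stacked along dim=1. A hedged sketch of reusing the saved head instead of re-encoding the prompts; the logit_scale access path mirrors the notebook's evaluate(), while dtype/device handling and the variable names are assumptions:

import torch

text_features = torch.load("zeroshot_classifier.pt", map_location=device)  # [embed_dim, num_classes]

with torch.no_grad():
    feats = model.encode_image(images, normalize=True)              # [B, embed_dim]
    logits = model.model.logit_scale.exp() * feats @ text_features  # scaling as in evaluate()
    top5_pred = logits.topk(5, dim=-1).indices                      # [B, 5] predicted class ids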