ynhe committed
Commit ff495b4 · verified · 1 Parent(s): 7fc7e58

[Init] upload model

.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
config.json ADDED
@@ -0,0 +1,191 @@
{
  "architectures": [
    "InternVideo2_CLIP_small"
  ],
  "auto_map": {
    "AutoConfig": "config.InternVideo2Config",
    "AutoModel": "modeling_internvideo2encoder.InternVideo2_CLIP_small"
  },
  "auto_resume": false,
  "batch_size": 64,
  "batch_size_test": 4,
  "best_key": [
    "msrvtt_1k_test_match",
    "t2v_r1"
  ],
  "compile_model": false,
  "criterion": {
    "clip_loss_ratio": [
      1.0,
      1.0
    ],
    "distill_final_features": true,
    "loss_weight": {
      "mlm": 1.0,
      "mvm": 0.0,
      "uta": 0.0,
      "vtc": 1.0,
      "vtm": 1.0
    },
    "mlm_masking_prob": 0.5,
    "vtm_hard_neg": true
  },
  "debug": false,
  "deep_fusion": false,
  "deepspeed": {
    "enable": true,
    "stage": 1
  },
  "delete_ds_optim_states": true,
  "device": "cuda",
  "dist_url": "env://",
  "evaluate": false,
  "evaluation": {
    "eval_frame_ensemble": "concat",
    "eval_offload": true,
    "eval_x_only": false,
    "k_test": 128
  },
  "gradient_checkpointing": true,
  "inputs": {
    "batch_size": {
      "image": 64,
      "video": 64
    },
    "batch_size_test": {
      "image": 4,
      "video": 4
    },
    "image_res": 224,
    "max_txt_l": {
      "image": 32,
      "video": 32
    },
    "video_input": {
      "num_frames": 8,
      "num_frames_test": 8,
      "random_aug": false,
      "sample_type": "middle",
      "sample_type_test": "middle"
    }
  },
  "jump_evaluate": false,
  "log_freq": 100,
  "max_txt_l": 32,
  "mode": "pt",
  "model": {
    "embed_dim": 1024,
    "find_unused_parameters": false,
    "freeze_text": true,
    "freeze_vision": true,
    "load_vision_ckpt_from_internvideo2_stage2": false,
    "model_cls": "InternVideo2_CLIP_small",
    "multimodal": {
      "enable": true
    },
    "open_text_projection": false,
    "open_vision_clip_projector": true,
    "temp": 0.01,
    "temp_min": 0.01,
    "text_encoder": {
      "embed_dim": 512,
      "image_cfg": {
        "image_size": 224,
        "model_name": "vit_b16"
      },
      "text_cfg": {
        "causal_masking": true,
        "context_length": 77,
        "dim": 512,
        "ffn_multiplier_per_layer": 4.0,
        "model_name": "base",
        "n_heads_per_layer": 8,
        "n_transformer_layers": 12,
        "norm_layer": "layer_norm_fp32",
        "vocab_size": 49408
      }
    },
    "vision_encoder": {
      "align_dim": 512,
      "attn_pool_num_heads": 16,
      "checkpoint_num": 0,
      "clip_embed_dim": 768,
      "depth": 24,
      "drop_cls_token": false,
      "drop_path_rate": 0.0,
      "embed_dim": 1024,
      "fused_mlp_heuristic": 1,
      "head_drop_path_rate": 0.0,
      "img_size": 224,
      "in_chans": 3,
      "init_values": 0.1,
      "layerscale_no_force_fp32": true,
      "mlp_ratio": 4,
      "name": "internvideo2_1B",
      "num_frames": 8,
      "num_heads": 16,
      "patch_size": 14,
      "qk_normalization": true,
      "qkv_bias": false,
      "sep_pos_embed": false,
      "tubelet_size": 1,
      "use_checkpoint": false,
      "use_flash_attn": false,
      "use_fused_mlp": false,
      "use_fused_rmsnorm": false
    }
  },
  "model_type": "internvideo2",
  "num_frames": 8,
  "num_frames_test": 8,
  "num_workers": 6,
  "optimizer": {
    "different_lr": {
      "enable": false,
      "lr": 0.001,
      "module_names": []
    },
    "lr": 5e-05,
    "max_grad_norm": 3.0,
    "opt": "adamW",
    "opt_betas": [
      0.9,
      0.98
    ],
    "weight_decay": 0.05
  },
  "output_dir": null,
  "pretrained_path": "",
  "resume": false,
  "save_ckpt_iter": null,
  "save_latest": true,
  "scheduler": {
    "epochs": 10,
    "min_lr_multi": 0.01,
    "sched": "cosine",
    "warmup_epochs": 1
  },
  "seed": 42,
  "test_file": {
    "didemo_ret_test": "available_corpus[\"didemo_ret_test\"]",
    "msrvtt_1k_test": "available_corpus[\"msrvtt_1k_test\"]"
  },
  "test_types": [
    "msrvtt_1k_test",
    "didemo_ret_test"
  ],
  "text_enc": "bert_large",
  "tokenizer": null,
  "torch_dtype": "float16",
  "train_file": "available_corpus[\"pretrain_example_data_1B\"]",
  "transformers_version": "4.51.3",
  "use_bf16": true,
  "use_flash_sdp": false,
  "use_half_precision": false,
  "use_mem_efficient_sdp": false,
  "wandb": {
    "enable": false,
    "entity": "opengvlab",
    "project": "InternVideo2-Stage2"
  }
}
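The auto_map above wires AutoConfig to config.InternVideo2Config and AutoModel to modeling_internvideo2encoder.InternVideo2_CLIP_small, so the checkpoint is meant to be loaded with trust_remote_code. A minimal loading sketch, mirroring demo.py below; the repo path is a placeholder for a local clone or the Hub id of this repository:

    from transformers import AutoConfig, AutoModel

    repo = "path/to/this/repo"  # placeholder path, not part of this commit
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True).to(config.device)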
config.py ADDED
@@ -0,0 +1,240 @@
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig

class EasyDict(dict):
    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and k not in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)

class InternVideo2Config(PretrainedConfig):
    model_type = "internvideo2"

    def __init__(self,
                 tokenizer=None,
                 train_file=None,
                 test_file=None,
                 test_types=None,
                 num_workers=6,
                 best_key=None,
                 num_frames=8,
                 num_frames_test=8,
                 batch_size=64,
                 batch_size_test=4,
                 max_txt_l=32,
                 inputs=None,
                 text_enc="bert_large",
                 model=None,
                 criterion=None,
                 optimizer=None,
                 scheduler=None,
                 evaluate=False,
                 deep_fusion=False,
                 evaluation=None,
                 use_half_precision=False,
                 use_bf16=True,
                 gradient_checkpointing=True,
                 use_flash_sdp=False,
                 use_mem_efficient_sdp=False,
                 compile_model=False,
                 wandb=None,
                 dist_url="env://",
                 device="cuda",
                 mode="pt",
                 output_dir=None,
                 resume=False,
                 debug=False,
                 log_freq=100,
                 seed=42,
                 save_latest=True,
                 auto_resume=False,
                 jump_evaluate=False,
                 pretrained_path="",
                 save_ckpt_iter=None,
                 delete_ds_optim_states=True,
                 deepspeed=None,
                 **kwargs):
        super().__init__(**kwargs)

        self.tokenizer = tokenizer

        # Data configuration
        self.train_file = train_file or "available_corpus[\"pretrain_example_data_1B\"]"
        self.test_file = EasyDict(test_file or {
            "msrvtt_1k_test": "available_corpus[\"msrvtt_1k_test\"]",
            "didemo_ret_test": "available_corpus[\"didemo_ret_test\"]"
        })
        self.test_types = test_types or ["msrvtt_1k_test", "didemo_ret_test"]
        self.num_workers = num_workers
        self.best_key = best_key or ["msrvtt_1k_test_match", "t2v_r1"]

        # Input configuration
        self.num_frames = num_frames
        self.num_frames_test = num_frames_test
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.max_txt_l = max_txt_l
        self.inputs = EasyDict(inputs or {
            "image_res": 224,
            "video_input": EasyDict({
                "num_frames": num_frames,
                "sample_type": "rand",
                "num_frames_test": num_frames_test,
                "sample_type_test": "middle",
                "random_aug": False
            }),
            "max_txt_l": EasyDict({"image": max_txt_l, "video": max_txt_l}),
            "batch_size": EasyDict({"image": batch_size, "video": batch_size}),
            "batch_size_test": EasyDict({"image": batch_size_test, "video": batch_size_test})
        })

        # Model configuration
        self.text_enc = text_enc
        self.model = EasyDict(model or {
            "model_cls": "InternVideo2_Stage2",
            "vision_encoder": EasyDict({
                "name": "pretrain_internvideo2_1b_patch14_224",
                "img_size": 224,
                "num_frames": num_frames,
                "tubelet_size": 1,
                "patch_size": 14,
                "d_model": 1408,
                "clip_embed_dim": 768,
                "clip_teacher_embed_dim": 3200,
                "clip_teacher_final_dim": 768,
                "clip_norm_type": "l2",
                "clip_return_layer": 6,
                "clip_student_return_interval": 1,
                "pretrained": None,
                "use_checkpoint": False,
                "checkpoint_num": 40,
                "use_flash_attn": True,
                "use_fused_rmsnorm": True,
                "use_fused_mlp": True,
                "clip_teacher": None,
                "clip_input_resolution": 224,
                "clip_teacher_return_interval": 1,
                "video_mask_type": "random",
                "video_mask_ratio": 0.8,
                "image_mask_type": "random",
                "image_mask_ratio": 0.5,
                "sep_image_video_pos_embed": True,
                "keep_temporal": False,
                "only_mask": True
            }),
            "text_encoder": text_enc,
            "multimodal": EasyDict({"enable": True}),
            "embed_dim": 512,
            "temp": 0.07,
            "find_unused_parameters": False
        })

        # Criterion configuration
        self.criterion = EasyDict(criterion or {
            "loss_weight": EasyDict({
                "vtc": 1.0,
                "mlm": 1.0,
                "vtm": 1.0,
                "mvm": 0.0,
                "uta": 0.0
            }),
            "vtm_hard_neg": True,
            "mlm_masking_prob": 0.5,
            "distill_final_features": True,
            "clip_loss_ratio": [1.0, 1.0]
        })

        # Optimizer configuration
        self.optimizer = EasyDict(optimizer or {
            "opt": "adamW",
            "lr": 5e-5,
            "opt_betas": [0.9, 0.98],
            "weight_decay": 0.05,
            "max_grad_norm": 3.0,
            "different_lr": EasyDict({"enable": False, "module_names": [], "lr": 1e-3})
        })

        # Scheduler configuration
        self.scheduler = EasyDict(scheduler or {
            "sched": "cosine",
            "epochs": 10,
            "min_lr_multi": 0.01,
            "warmup_epochs": 1
        })

        # Evaluation configuration
        self.evaluate = evaluate
        self.deep_fusion = deep_fusion
        self.evaluation = EasyDict(evaluation or {
            "eval_frame_ensemble": "concat",
            "eval_x_only": False,
            "k_test": 128,
            "eval_offload": True
        })

        # Miscellaneous
        self.use_half_precision = use_half_precision
        self.use_bf16 = use_bf16
        self.gradient_checkpointing = gradient_checkpointing
        self.use_flash_sdp = use_flash_sdp
        self.use_mem_efficient_sdp = use_mem_efficient_sdp
        self.compile_model = compile_model

        self.wandb = EasyDict(wandb or {
            "enable": False,
            "entity": "opengvlab",
            "project": "InternVideo2-Stage2"
        })

        self.dist_url = dist_url
        self.device = device
        self.mode = mode
        self.output_dir = output_dir
        self.resume = resume
        self.debug = debug
        self.log_freq = log_freq
        self.seed = seed

        self.save_latest = save_latest
        self.auto_resume = auto_resume
        self.jump_evaluate = jump_evaluate
        self.pretrained_path = pretrained_path
        self.save_ckpt_iter = save_ckpt_iter
        self.delete_ds_optim_states = delete_ds_optim_states

        self.deepspeed = EasyDict(deepspeed or {
            "enable": True,
            "stage": 1
        })

    def set_num_frames(self, num_frames):
        self.num_frames = num_frames
        self.inputs.video_input.num_frames = num_frames
        self.model.vision_encoder.num_frames = num_frames
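InternVideo2Config fills nested EasyDict defaults for inputs, model, criterion, optimizer, scheduler, and related fields, and set_num_frames keeps the top-level, input, and vision-encoder frame counts in sync. A small sketch of that behavior, assuming it runs next to config.py so the class can be imported directly:

    from config import InternVideo2Config

    cfg = InternVideo2Config()
    print(cfg.inputs.video_input.num_frames)    # 8 by default
    print(cfg.model.vision_encoder.num_frames)  # 8 by default

    cfg.set_num_frames(16)
    assert cfg.num_frames == 16
    assert cfg.inputs.video_input.num_frames == 16
    assert cfg.model.vision_encoder.num_frames == 16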
demo.py ADDED
@@ -0,0 +1,143 @@
import os
import random
import io
import av
import cv2
import decord
import imageio
from decord import VideoReader
import torch
import numpy as np
import math
import torch.nn.functional as F
decord.bridge.set_bridge("torch")

from transformers import AutoConfig, AutoModel
config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1, start=None, end=None):
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    # Ensure start_frame is less than end_frame
    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    # Calculate the length of the clip in frames
    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, clip_length)
        # split the clip into `acc_samples` intervals, and sample from each interval
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except:
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between frames; this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices

def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
):
    num_threads = 1 if video_path.endswith('.webm') else 0  # make ssv2 happy

    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)

    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )

    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration

def get_text_feature(model, texts):
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features

def get_similarity(video_feature, text_feature):
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix

def get_top_videos(model, text_features, video_features, video_paths, texts):
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    sim_matrix = text_features @ video_features.T  # (num_texts, num_videos)

    top_k = min(5, len(video_paths))  # cap at the number of candidate videos
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]

    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            print("top", j + 1, ":", video_paths[sim_matrix_top_k[i][j]], "~prob:", sim_matrix[i][sim_matrix_top_k[i][j]].item())
            retrieval_infos[texts[i]].append({"video": video_paths[sim_matrix_top_k[i][j]], "prob": sim_matrix[i][sim_matrix_top_k[i][j]].item(), "rank": j + 1})
    return retrieval_infos

if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']
    for video_path in demo_videos:
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
        video_features.append(video_feature)

    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
flash_attention_class.py ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None;
                 if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None
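FlashAttention expects a packed qkv tensor of shape (B, S, 3, H, D) in fp16 or bf16 on a CUDA device and returns the attention output plus None in place of attention weights. A minimal sketch, assuming flash-attn is installed, a GPU is available, and the FlashAttention class above is importable; the tensor sizes are illustrative (head dim 88 matches the 1B encoder, 1408 / 16):

    import torch
    from flash_attention_class import FlashAttention  # assumed local import path

    attn = FlashAttention(attention_dropout=0.0).cuda().eval()
    qkv = torch.randn(2, 197, 3, 16, 88, dtype=torch.float16, device="cuda")  # (B, S, 3, H, D)
    with torch.no_grad():
        out, _ = attn(qkv)
    print(out.shape)  # (B, S, H, D) -> torch.Size([2, 197, 16, 88])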
internvideo2.py ADDED
@@ -0,0 +1,779 @@
import math
import torch
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch import nn

import torch.utils.checkpoint as checkpoint
from functools import partial
from einops import rearrange

from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed, interpolate_pos_embed_internvideo2
from .flash_attention_class import FlashAttention

from transformers.utils import logging as error_logging

# Set up logging
error_logging.set_verbosity_error()

try:
    from flash_attn.modules.mlp import Mlp as FusedMLP
except:
    pass

try:
    from flash_attn.ops.rms_norm import DropoutAddRMSNorm
except:
    pass


class CrossAttention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class AttentiveBlock(nn.Module):

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_attn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        # x_q = x.mean(1, keepdim=True)
        x_q = x
        x_kv, pos_q, pos_k = x, 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
        x = x.squeeze(1)
        return x


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))
        self.force_fp32 = force_fp32

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x):
        if self.force_fp32:
            output_type = x.dtype
            out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
            return out.to(dtype=output_type)
        else:
            out = x.mul_(self.gamma) if self.inplace else x * self.gamma
            return out


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
                 causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.use_flash_attn = use_flash_attn
        if use_flash_attn:
            self.causal = causal
            self.inner_attn = FlashAttention(attention_dropout=attn_drop)

        self.qk_normalization = qk_normalization
        self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def _naive_attn(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        # attn = attn - attn.max(-1)[0].unsqueeze(-1)  # in case of overflow for fp16
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            if self.use_fused_rmsnorm:
                q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
                k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
            else:
                q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
                k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, x):
        x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
        return x


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
            fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
            use_fused_rmsnorm=False):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                              use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
                              qk_normalization=qk_normalization,
                              use_fused_rmsnorm=use_fused_rmsnorm)
        self.ls1 = LayerScale(dim, init_values=init_values,
                              force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if use_fused_mlp:
            # self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values,
                              force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.with_cp = with_cp
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def forward(self, x, residual=None):

        def _inner_forward(x, residual=None):
            if self.use_fused_rmsnorm:
                x, residual = self.norm1(x, residual)
                x = self.drop_path1(self.ls1(self.attn(x)))
                x, residual = self.norm2(x, residual)
                x = self.drop_path2(self.ls2(self.mlp(x)))
                return x, residual
            else:
                assert residual is None
                x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
                x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
                return x

        if self.with_cp:
            return checkpoint.checkpoint(_inner_forward, x, residual)
        else:
            return _inner_forward(x, residual=residual)


class PatchEmbed(nn.Module):
    """ 3D Image to Patch Embedding
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
            num_frames=8, tubelet_size=1, norm_layer=None
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            num_frames // tubelet_size,
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1]
        )  # (T, H, W)
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.num_img_patches = self.grid_size[1] * self.grid_size[2]

        self.proj = nn.Conv3d(
            in_channels=in_chans, out_channels=embed_dim,
            kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
            stride=(tubelet_size, patch_size[0], patch_size[1])
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(3).permute(0, 2, 3, 1)  # B x C x T x HW => B x T x HW x C
        x = self.norm(x)
        return x


class Linear_Decoder(nn.Module):
    def __init__(self, in_channels=1408, out_channels=3200,
                 norm_layer=nn.LayerNorm, clip_norm_type='l2'):
        super().__init__()
        self.clip_norm_type = clip_norm_type

        self.head = nn.Linear(in_channels, out_channels)
        self.norm = norm_layer(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.norm(self.head(x))

        if self.clip_norm_type == 'l2':
            x = x / x.norm(dim=-1, keepdim=True)
        elif self.clip_norm_type == 'none':
            pass
        else:
            raise NotImplementedError

        return x


class PretrainInternVideo2(nn.Module):
    def __init__(
            self,
            in_chans: int = 3,
            patch_size: int = 14,
            img_size: int = 224,
            qkv_bias: bool = False,
            drop_path_rate: float = 0.25,
            embed_dim: int = 1408,
            num_heads: int = 16,
            mlp_ratio: float = 48/11,
            init_values: float = 1e-5,
            qk_normalization: bool = True,
            depth: int = 40,
            use_flash_attn: bool = False,
            use_fused_rmsnorm: bool = False,
            use_fused_mlp: bool = False,
            fused_mlp_heuristic: int = 1,
            attn_pool_num_heads: int = 16,
            clip_embed_dim: int = 768,
            layerscale_no_force_fp32: bool = False,
            num_frames: int = 8,
            tubelet_size: int = 1,
            sep_pos_embed: bool = False,
            sep_image_video_pos_embed: bool = False,
            use_checkpoint: bool = False,
            checkpoint_num: int = 0,
            # for unmasked teacher
            clip_teacher_embed_dim: int = 3200,
            clip_teacher_final_dim: int = 768,  # if 0, not distill final features
            clip_norm_type: str = 'l2',
            clip_return_layer: int = 1,
            clip_student_return_interval: int = 1,
    ):
        super().__init__()

        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'

        self.use_flash_attn = use_flash_attn
        self.embed_dim = embed_dim

        self.depth = depth
        self.clip_norm_type = clip_norm_type
        self.return_index = []
        for i in range(clip_return_layer):
            self.return_index.append(depth - int(i * clip_student_return_interval) - 1)

        if use_fused_rmsnorm:
            norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
        else:
            norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
        self.norm_layer_for_blocks = norm_layer_for_blocks
        self.patch_embed = PatchEmbed(
            img_size, patch_size, in_chans, embed_dim,
            num_frames=num_frames, tubelet_size=tubelet_size,
        )
        num_patches = self.patch_embed.num_patches
        num_img_patches = self.patch_embed.num_img_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
        self.sep_pos_embed = sep_pos_embed
        self.sep_image_video_pos_embed = sep_image_video_pos_embed
        if sep_pos_embed:
            raise NotImplementedError
        else:
            if sep_image_video_pos_embed:
                # joint position embedding; image and video use different pos_embed
                self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
                self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
                # for CLIP decoder
                self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
                self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
            else:
                # joint position embedding; image and video share the same pos_embed
                self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
                self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        # choose which layers use checkpointing
        with_cp_list = [False] * depth
        if use_checkpoint:
            for idx in range(depth):
                if idx < checkpoint_num:
                    with_cp_list[idx] = True

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
                  norm_layer=norm_layer_for_blocks,
                  drop_path=dpr[i], init_values=init_values, attn_drop=0.,
                  use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
                  fused_mlp_heuristic=fused_mlp_heuristic,
                  with_cp=with_cp_list[i],
                  qk_normalization=qk_normalization,
                  layerscale_no_force_fp32=layerscale_no_force_fp32,
                  use_fused_rmsnorm=use_fused_rmsnorm)
            for i in range(depth)])
        self.clip_projector = AttentionPoolingBlock(
            dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)

        # CLIP decoder
        self.clip_decoder = nn.ModuleList([
            Linear_Decoder(
                in_channels=embed_dim,
                out_channels=clip_teacher_embed_dim,
                norm_layer=partial(nn.LayerNorm, eps=1e-5),
                clip_norm_type=clip_norm_type
            ) for _ in range(clip_return_layer)
        ])
        self.final_clip_decoder = nn.Identity()
        if clip_teacher_final_dim > 0:
            self.final_clip_decoder = Linear_Decoder(
                in_channels=clip_embed_dim,
                out_channels=clip_teacher_final_dim,
                norm_layer=partial(nn.LayerNorm, eps=1e-5),
                clip_norm_type=clip_norm_type
            )

        self.init_pos_embed()
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def init_pos_embed(self):
        # init pos_embed from sincos pos_embed
        if self.sep_pos_embed:
            raise NotImplementedError
        else:
            pos_embed = get_3d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                self.patch_embed.grid_size[1],  # height & width
                self.patch_embed.grid_size[0],  # t_size
                cls_token=True
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
            self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

            if self.sep_image_video_pos_embed:
                img_pos_embed = get_3d_sincos_pos_embed(
                    self.pos_embed.shape[-1],
                    self.patch_embed.grid_size[1],  # height & width
                    1,
                    cls_token=True
                )
                self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
                self.clip_img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @property
    def dtype(self):
        return self.patch_embed.proj.weight.dtype

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'pos_embed',
            'pos_embed_spatial',
            'pos_embed_temporal',
            'pos_embed_cls',
            'img_pos_embed',
            'cls_token',
            'clip_pos_embed',
            'clip_pos_embed_spatial',
            'clip_pos_embed_temporal',
            'clip_pos_embed_cls',
            'clip_img_pos_embed'
        }

    def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False):
        x = self.patch_embed(x.type(self.dtype))
        B, T, L, C = x.shape  # T: temporal; L: spatial
        x = x.view([B, T * L, C])

        # append cls token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, T * L + 1, C)

        # add pos_embed
        if self.sep_pos_embed:
            raise NotImplementedError
        else:
            if use_image:
                if self.sep_image_video_pos_embed:
                    pos_embed = self.img_pos_embed
                else:
                    # (1, num_img_patches + 1, embed_dim)
                    cls_pos_embed = self.pos_embed[:, 0:1, :]
                    img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
                    pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
            else:
                pos_embed = self.pos_embed
        pos_embed = pos_embed[:, :x.shape[1], :]
        x = x + pos_embed

        # mask tokens, ~mask means visible
        if mask is not None:
            x = x[~mask].reshape(B, -1, C)
        else:
            x = x.reshape(B, -1, C)

        residual = None
        x_clip = []
        for idx, blk in enumerate(self.blocks):
            if isinstance(x, tuple) and len(x) == 2:
                x, residual = x
            x = blk(x, residual=residual)
            # return intermediate features
            if idx in self.return_index:
                if isinstance(x, tuple) and len(x) == 2:
                    tmp_x, tmp_residual = x
                    if residual is not None:
                        x_clip.append(tmp_x + tmp_residual)
                else:
                    x_clip.append(x)
            if idx == (self.depth + x_vis_return_idx):
                break

        if isinstance(x, tuple) and len(x) == 2:
            x, residual = x
            if residual is not None:
                x = x + residual

        x_vis = x
        if x_vis_only:
            return x_vis

        x_pool_vis = self.clip_projector(x_vis)
        x_align = self.final_clip_decoder(x_pool_vis)

        # align CLIP
        x_clip = torch.stack(x_clip)
        K, B, _, C_CLIP = x_clip.shape
        # add pos_embed
        if self.sep_pos_embed:
            raise NotImplementedError
        else:
            if use_image:
                if self.sep_image_video_pos_embed:
                    clip_pos_embed = self.clip_img_pos_embed
                else:
                    # (1, num_img_patches + 1, embed_dim)
                    clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :]
                    clip_img_pos_embed = self.clip_pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
                    clip_pos_embed = torch.cat([clip_cls_pos_embed, clip_img_pos_embed], dim=1)
            else:
                clip_pos_embed = self.clip_pos_embed

        clip_pos_embed = clip_pos_embed.repeat(B, 1, 1)
        if mask is not None:
            x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)
        else:
            clip_pos_embed = clip_pos_embed.unsqueeze(0).repeat(K, 1, 1, 1)
            clip_pos_embed = clip_pos_embed[:, :, :x_clip.shape[2], :]
            x_clip = x_clip + clip_pos_embed

        # CLIP decoder
        x_clip_align = []
        for idx, clip_decoder in enumerate(self.clip_decoder):
            x_clip_align.append(clip_decoder(x_clip[idx]))
        x_clip_align = torch.stack(x_clip_align)

        return x_vis, x_pool_vis, x_clip_align, x_align


def pretrain_internvideo2_1b_patch14_224(config):
    model = PretrainInternVideo2(
        in_chans=3, img_size=224, patch_size=14,
        embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
        clip_embed_dim=config.vision_encoder.clip_embed_dim,
        attn_pool_num_heads=16, qkv_bias=False,
        drop_path_rate=0.25,
        init_values=0.00001,
        qk_normalization=True,
        use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
        use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
        use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
        fused_mlp_heuristic=1,
        layerscale_no_force_fp32=False,
        num_frames=config.vision_encoder.num_frames,
        tubelet_size=config.vision_encoder.tubelet_size,
        sep_pos_embed=False,
        sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
        use_checkpoint=config.vision_encoder.use_checkpoint,
        checkpoint_num=config.vision_encoder.checkpoint_num,
        clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
        clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
        clip_norm_type=config.vision_encoder.clip_norm_type,
        clip_return_layer=config.vision_encoder.clip_return_layer,
        clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
    )

    if config.vision_encoder.pretrained is not None:
        state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
        interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
        message = model.load_state_dict(state_dict, strict=False)
    return model


def pretrain_internvideo2_6b_patch14_224(config):
    model = PretrainInternVideo2(
        in_chans=3, img_size=224, patch_size=14,
        embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
        clip_embed_dim=config.vision_encoder.clip_embed_dim,
        attn_pool_num_heads=16, qkv_bias=False,
        drop_path_rate=0.3,
        init_values=0.00001,
        qk_normalization=True,
        use_flash_attn=config.vision_encoder.get('use_flash_attn', True),
        use_fused_rmsnorm=config.vision_encoder.get('use_fused_rmsnorm', True),
        use_fused_mlp=config.vision_encoder.get('use_fused_mlp', True),
        fused_mlp_heuristic=1,
        layerscale_no_force_fp32=False,
        num_frames=config.vision_encoder.num_frames,
        tubelet_size=config.vision_encoder.tubelet_size,
        sep_pos_embed=False,
        sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed,
        use_checkpoint=config.vision_encoder.use_checkpoint,
        checkpoint_num=config.vision_encoder.checkpoint_num,
        clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim,
        clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim,
        clip_norm_type=config.vision_encoder.clip_norm_type,
        clip_return_layer=config.vision_encoder.clip_return_layer,
        clip_student_return_interval=config.vision_encoder.clip_student_return_interval,
    )

    if config.vision_encoder.pretrained is not None:
        state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
        interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8)
        msg = model.load_state_dict(state_dict, strict=False)
    return model
internvideo2_clip_vision.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6
+ from timm.models.registry import register_model
7
+ from torch import nn
8
+
9
+ import torch.utils.checkpoint as checkpoint
10
+ from functools import partial
11
+ from einops import rearrange
12
+
13
+ from .pos_embed import get_3d_sincos_pos_embed, get_2d_sincos_pos_embed, get_1d_sincos_pos_embed
14
+ from .flash_attention_class import FlashAttention
15
+ from flash_attn.modules.mlp import FusedMLP
16
+ try:
17
+ from flash_attn.ops.rms_norm import DropoutAddRMSNorm
18
+ except ImportError:
19
+ pass
20
+
21
+ from transformers.utils import logging
22
+ import warnings
23
+ warnings.filterwarnings("ignore")
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class CrossAttention(nn.Module):
28
+ def __init__(
29
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
30
+ proj_drop=0., attn_head_dim=None, out_dim=None):
31
+ super().__init__()
32
+ if out_dim is None:
33
+ out_dim = dim
34
+ self.num_heads = num_heads
35
+ head_dim = dim // num_heads
36
+ if attn_head_dim is not None:
37
+ head_dim = attn_head_dim
38
+ all_head_dim = head_dim * self.num_heads
39
+ self.scale = qk_scale or head_dim ** -0.5
40
+ assert all_head_dim == dim
41
+
42
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
43
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
44
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
45
+
46
+ if qkv_bias:
47
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
48
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
49
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
50
+ else:
51
+ self.q_bias = None
52
+ self.k_bias = None
53
+ self.v_bias = None
54
+
55
+ self.attn_drop = nn.Dropout(attn_drop)
56
+ self.proj = nn.Linear(all_head_dim, out_dim)
57
+ self.proj_drop = nn.Dropout(proj_drop)
58
+
59
+ def forward(self, x, k=None, v=None):
60
+ B, N, C = x.shape
61
+ N_k = k.shape[1]
62
+ N_v = v.shape[1]
63
+
64
+ q_bias, k_bias, v_bias = None, None, None
65
+ if self.q_bias is not None:
66
+ q_bias = self.q_bias
67
+ k_bias = self.k_bias
68
+ v_bias = self.v_bias
69
+
70
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
71
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
72
+
73
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
74
+ k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
75
+
76
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
77
+ v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
78
+
79
+ q = q * self.scale
80
+ attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
81
+
82
+ attn = attn.softmax(dim=-1)
83
+ attn = self.attn_drop(attn)
84
+
85
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
86
+ x = self.proj(x)
87
+ x = self.proj_drop(x)
88
+
89
+ return x
90
+
91
+
92
+ class AttentiveBlock(nn.Module):
93
+
94
+ def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
95
+ drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
96
+ super().__init__()
97
+
98
+ self.norm1_q = norm_layer(dim)
99
+ self.norm1_k = norm_layer(dim)
100
+ self.norm1_v = norm_layer(dim)
101
+ self.cross_attn = CrossAttention(
102
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
103
+ proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
104
+
105
+ if drop_path > 0.:
106
+ logger.info(f"Use DropPath in projector: {drop_path}")
107
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
108
+
109
+ def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
110
+ x_q = self.norm1_q(x_q + pos_q)
111
+ x_k = self.norm1_k(x_kv + pos_k)
112
+ x_v = self.norm1_v(x_kv)
113
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
114
+
115
+ return x
116
+
117
+
118
+ class AttentionPoolingBlock(AttentiveBlock):
119
+
120
+ def forward(self, x):
121
+ x_q = x.mean(1, keepdim=True)
122
+ x_kv, pos_q, pos_k = x, 0, 0
123
+ x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
124
+ x = x.squeeze(1)
125
+ return x
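
The attention-pooling head above is what turns the final token sequence into a single clip-level embedding: the mean token attends over all tokens. A minimal sketch with illustrative sizes, assuming the classes defined above are in scope:

import torch

pool = AttentionPoolingBlock(
    dim=64, num_heads=4, qkv_bias=True, qk_scale=None,
    drop=0., attn_drop=0., drop_path=0., norm_layer=torch.nn.LayerNorm, out_dim=32,
)
tokens = torch.randn(2, 10, 64)   # [B, N, C]
pooled = pool(tokens)             # [B, out_dim] == [2, 32]
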
126
+
127
+
128
+ class RMSNorm(nn.Module):
129
+ def __init__(self, hidden_size, eps=1e-6):
130
+ super().__init__()
131
+ self.weight = nn.Parameter(torch.ones(hidden_size))
132
+ self.variance_epsilon = eps
133
+
134
+ def forward(self, hidden_states):
135
+ input_dtype = hidden_states.dtype
136
+ hidden_states = hidden_states.to(torch.float32)
137
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
138
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
139
+ return self.weight * hidden_states.to(input_dtype)
140
+
141
+
142
+ class LayerScale(nn.Module):
143
+ def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
144
+ super().__init__()
145
+ self.inplace = inplace
146
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
147
+ self.force_fp32 = force_fp32
148
+
149
+ @torch.cuda.amp.autocast(enabled=False)
150
+ def forward(self, x):
151
+ if self.force_fp32:
152
+ output_type = x.dtype
153
+ out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
154
+ return out.to(dtype=output_type)
155
+ else:
156
+ out = x.mul_(self.gamma) if self.inplace else x * self.gamma
157
+ return out
158
+
159
+
160
+ class Attention(nn.Module):
161
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
162
+ causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
163
+ super().__init__()
164
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
165
+ self.num_heads = num_heads
166
+ head_dim = dim // num_heads
167
+ self.scale = head_dim ** -0.5
168
+
169
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
170
+ self.attn_drop = nn.Dropout(attn_drop)
171
+ self.proj = nn.Linear(dim, dim)
172
+ self.proj_drop = nn.Dropout(proj_drop)
173
+
174
+ self.use_flash_attn = use_flash_attn
175
+ if use_flash_attn:
176
+ self.causal = causal
177
+ self.inner_attn = FlashAttention(attention_dropout=attn_drop)
178
+
179
+ self.qk_normalization = qk_normalization
180
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
181
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
182
+ self.use_fused_rmsnorm = use_fused_rmsnorm
183
+
184
+ def _naive_attn(self, x):
185
+ B, N, C = x.shape
186
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
187
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
188
+
189
+ if self.qk_normalization:
190
+ B_, H_, N_, D_ = q.shape
191
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
192
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
193
+
194
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
195
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
196
+ attn = attn.softmax(dim=-1)
197
+ attn = self.attn_drop(attn)
198
+
199
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
200
+ x = self.proj(x)
201
+ x = self.proj_drop(x)
202
+ return x
203
+
204
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
205
+
206
+ qkv = self.qkv(x)
207
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
208
+
209
+ if self.qk_normalization:
210
+ q, k, v = qkv.unbind(2)
211
+ if self.use_fused_rmsnorm:
212
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
213
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
214
+ else:
215
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
216
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
217
+ qkv = torch.stack([q, k, v], dim=2)
218
+
219
+ context, _ = self.inner_attn(
220
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
221
+ )
222
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
223
+ outs = self.proj_drop(outs)
224
+ return outs
225
+
226
+ def forward(self, x):
227
+ x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
228
+ return x
229
+
230
+
231
+ class Mlp(nn.Module):
232
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
233
+ """
234
+
235
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
236
+ bias=True, drop=0.):
237
+ super().__init__()
238
+ out_features = out_features or in_features
239
+ hidden_features = hidden_features or in_features
240
+ bias = to_2tuple(bias)
241
+ drop_probs = to_2tuple(drop)
242
+
243
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
244
+ self.act = act_layer()
245
+ self.drop1 = nn.Dropout(drop_probs[0])
246
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
247
+ self.drop2 = nn.Dropout(drop_probs[1])
248
+
249
+ def forward(self, x):
250
+ x = self.fc1(x)
251
+ x = self.act(x)
252
+ x = self.drop1(x)
253
+ x = self.fc2(x)
254
+ x = self.drop2(x)
255
+ return x
256
+
257
+
258
+ class Block(nn.Module):
259
+
260
+ def __init__(
261
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
262
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
263
+ fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
264
+ use_fused_rmsnorm=False):
265
+ super().__init__()
266
+
267
+ self.norm1 = norm_layer(dim)
268
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
269
+ use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
270
+ qk_normalization=qk_normalization,
271
+ use_fused_rmsnorm=use_fused_rmsnorm)
272
+ self.ls1 = LayerScale(dim, init_values=init_values,
273
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
274
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
276
+
277
+ self.norm2 = norm_layer(dim)
278
+ mlp_hidden_dim = int(dim * mlp_ratio)
279
+ if use_fused_mlp:
280
+ self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
281
+ else:
282
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
283
+ self.ls2 = LayerScale(dim, init_values=init_values,
284
+ force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
285
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
286
+
287
+ self.with_cp = with_cp
288
+ self.use_fused_rmsnorm = use_fused_rmsnorm
289
+
290
+ def forward(self, x, residual=None):
291
+
292
+ def _inner_forward(x, residual=None):
293
+ if self.use_fused_rmsnorm:
294
+ x, residual = self.norm1(x, residual)
295
+ x = self.drop_path1(self.ls1(self.attn(x)))
296
+ x, residual = self.norm2(x, residual)
297
+ x = self.drop_path2(self.ls2(self.mlp(x)))
298
+ return x, residual
299
+ else:
300
+ assert residual is None
301
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
302
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
303
+ return x
304
+
305
+ if self.with_cp:
306
+ return checkpoint.checkpoint(_inner_forward, x, residual)
307
+ else:
308
+ return _inner_forward(x, residual=residual)
309
+
310
+
311
+ class PatchEmbed(nn.Module):
312
+ """ 3D Image to Patch Embedding
313
+ """
314
+
315
+ def __init__(
316
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
317
+ num_frames=8, tubelet_size=1, norm_layer=None
318
+ ):
319
+ super().__init__()
320
+ img_size = to_2tuple(img_size)
321
+ patch_size = to_2tuple(patch_size)
322
+ self.img_size = img_size
323
+ self.patch_size = patch_size
324
+ self.tubelet_size = tubelet_size
325
+ self.grid_size = (
326
+ num_frames // tubelet_size,
327
+ img_size[0] // patch_size[0],
328
+ img_size[1] // patch_size[1]
329
+ ) # (T, H, W)
330
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
331
+
332
+ self.proj = nn.Conv3d(
333
+ in_channels=in_chans, out_channels=embed_dim,
334
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
335
+ stride=(tubelet_size, patch_size[0], patch_size[1])
336
+ )
337
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
338
+
339
+ def forward(self, x):
340
+ x = self.proj(x)
341
+ x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C
342
+ x = self.norm(x)
343
+ return x
344
+
345
+
346
+ class InternVideo2(nn.Module):
347
+ def __init__(
348
+ self,
349
+ in_chans: int = 3,
350
+ patch_size: int = 14,
351
+ img_size: int = 224,
352
+ qkv_bias: bool = False,
353
+ drop_path_rate: float = 0.25, # may need ablation
354
+ head_drop_path_rate: float = 0.,
355
+ embed_dim: int = 1408,
356
+ num_heads: int = 16,
357
+ mlp_ratio: float = 48/11,
358
+ init_values: float = 1e-5, # may need ablation
359
+ qk_normalization: bool = True,
360
+ depth: int = 40,
361
+ use_flash_attn: bool = True,
362
+ use_fused_rmsnorm: bool = True,
363
+ use_fused_mlp: bool = True,
364
+ fused_mlp_heuristic: int = 1,
365
+ attn_pool_num_heads: int = 16,
366
+ clip_embed_dim: int = 768,
367
+ layerscale_no_force_fp32: bool = False, # when True for training?
368
+ num_frames: int = 8,
369
+ tubelet_size: int = 1,
370
+ sep_pos_embed: bool = False,
371
+ use_checkpoint: bool = False,
372
+ checkpoint_num: int = 0,
373
+ ):
374
+ super().__init__()
375
+
376
+ assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, logger.info(
377
+ 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent')
378
+
379
+ self.use_flash_attn = use_flash_attn
380
+ self.embed_dim = embed_dim
381
+ self.T = num_frames // tubelet_size
382
+
383
+ if use_fused_rmsnorm:
384
+ norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
385
+ else:
386
+ norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
387
+ self.norm_layer_for_blocks = norm_layer_for_blocks
388
+ self.patch_embed = PatchEmbed(
389
+ img_size, patch_size, in_chans, embed_dim,
390
+ num_frames=num_frames, tubelet_size=tubelet_size,
391
+ )
392
+ num_patches = self.patch_embed.num_patches
393
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
394
+
395
+ # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
396
+ self.sep_pos_embed = sep_pos_embed
397
+ if sep_pos_embed:
398
+ logger.info("Use seperable position embedding")
399
+ grid_size = self.patch_embed.grid_size
400
+ self.grid_size = grid_size
401
+ self.pos_embed_spatial = nn.Parameter(torch.zeros(1, grid_size[1] * grid_size[2], embed_dim))
402
+ self.pos_embed_temporal = nn.Parameter(torch.zeros(1, grid_size[0], embed_dim))
403
+ self.pos_embed_cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
404
+ else:
405
+ logger.info("Use joint position embedding")
406
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
407
+
408
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
409
+ # choose which layer to use checkpoint
410
+ with_cp_list = [False] * depth
411
+ if use_checkpoint:
412
+ for idx in range(depth):
413
+ if idx < checkpoint_num:
414
+ with_cp_list[idx] = True
415
+ logger.info(f"Droppath rate: {dpr}")
416
+ logger.info(f"Checkpoint list: {with_cp_list}")
417
+
418
+ self.blocks = nn.ModuleList([
419
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
420
+ norm_layer=norm_layer_for_blocks,
421
+ drop_path=dpr[i], init_values=init_values, attn_drop=0.,
422
+ use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
423
+ fused_mlp_heuristic=fused_mlp_heuristic,
424
+ with_cp=with_cp_list[i],
425
+ qk_normalization=qk_normalization,
426
+ layerscale_no_force_fp32=layerscale_no_force_fp32,
427
+ use_fused_rmsnorm=use_fused_rmsnorm)
428
+ for i in range(depth)])
429
+ self.clip_projector = AttentionPoolingBlock(
430
+ dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
431
+ drop=0., attn_drop=0., drop_path=head_drop_path_rate,
432
+ norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim
433
+ )
434
+
435
+ self.fc_norm = nn.Identity()
436
+
437
+ self.init_pos_embed()
438
+ trunc_normal_(self.cls_token, std=.02)
439
+ self.apply(self._init_weights)
440
+ self.fix_init_weight()
441
+
442
+ def init_pos_embed(self):
443
+ logger.info("Init pos_embed from sincos pos_embed")
444
+ if self.sep_pos_embed:
445
+ # trunc_normal_(self.pos_embed_spatial, std=.02)
446
+ # trunc_normal_(self.pos_embed_temporal, std=.02)
447
+ # trunc_normal_(self.pos_embed_cls, std=.02)
448
+ pos_embed_spatial = get_2d_sincos_pos_embed(
449
+ self.pos_embed_spatial.shape[-1],
450
+ self.patch_embed.grid_size[1], # height & width
451
+ )
452
+ self.pos_embed_spatial.data.copy_(torch.from_numpy(pos_embed_spatial).float().unsqueeze(0))
453
+ pos_embed_temporal = get_1d_sincos_pos_embed(
454
+ self.pos_embed_spatial.shape[-1],
455
+ self.patch_embed.grid_size[0], # t_size
456
+ )
457
+ self.pos_embed_temporal.data.copy_(torch.from_numpy(pos_embed_temporal).float().unsqueeze(0))
458
+ else:
459
+ # trunc_normal_(self.pos_embed, std=.02)
460
+ pos_embed = get_3d_sincos_pos_embed(
461
+ self.pos_embed.shape[-1],
462
+ self.patch_embed.grid_size[1], # height & width
463
+ self.patch_embed.grid_size[0], # t_size
464
+ cls_token=True
465
+ )
466
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
467
+
468
+ def _init_weights(self, m):
469
+ if isinstance(m, nn.Linear):
470
+ trunc_normal_(m.weight, std=.02)
471
+ if isinstance(m, nn.Linear) and m.bias is not None:
472
+ nn.init.constant_(m.bias, 0)
473
+ elif isinstance(m, nn.LayerNorm):
474
+ nn.init.constant_(m.bias, 0)
475
+ nn.init.constant_(m.weight, 1.0)
476
+
477
+ def fix_init_weight(self):
478
+ def rescale(param, layer_id):
479
+ param.div_(math.sqrt(2.0 * layer_id))
480
+
481
+ for layer_id, layer in enumerate(self.blocks):
482
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
483
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
484
+
485
+ @property
486
+ def dtype(self):
487
+ return self.patch_embed.proj.weight.dtype
488
+
489
+ def get_num_layers(self):
490
+ return len(self.blocks)
491
+
492
+ @torch.jit.ignore
493
+ def no_weight_decay(self):
494
+ return {
495
+ 'pos_embed',
496
+ 'pos_embed_spatial',
497
+ 'pos_embed_temporal',
498
+ 'pos_embed_cls',
499
+ 'cls_token'
500
+ }
501
+
502
+ def forward(self, x, use_image=False):
503
+ x = self.patch_embed(x.type(self.dtype))
504
+ B, T, L, C = x.shape # T: temporal; L: spatial
505
+ x = x.view([B, T * L, C])
506
+
507
+ # append cls token
508
+ cls_tokens = self.cls_token.expand(B, -1, -1)
509
+ x = torch.cat((cls_tokens, x), dim=1)
510
+
511
+ # add pos_embed
512
+ if self.sep_pos_embed:
513
+ if use_image:
514
+ pos_embed = self.pos_embed_spatial
515
+ else:
516
+ pos_embed = self.pos_embed_spatial.repeat(
517
+ 1, self.grid_size[0], 1
518
+ ) + torch.repeat_interleave(
519
+ self.pos_embed_temporal,
520
+ self.grid_size[1] * self.grid_size[2],
521
+ dim=1,
522
+ )
523
+ pos_embed = torch.cat(
524
+ [
525
+ self.pos_embed_cls.expand(pos_embed.shape[0], -1, -1),
526
+ pos_embed,
527
+ ],
528
+ 1,
529
+ )
530
+ else:
531
+ if use_image:
532
+ cls_pos_embed = self.pos_embed[:, :1, :]
533
+ img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.T, L, C).mean(dim=1)
534
+ pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
535
+ else:
536
+ pos_embed = self.pos_embed
537
+
538
+ x = x + pos_embed
539
+
540
+ residual = None
541
+ for blk in self.blocks:
542
+ if isinstance(x, tuple) and len(x) == 2:
543
+ x, residual = x
544
+ x = blk(x, residual=residual)
545
+ if isinstance(x, tuple) and len(x) == 2:
546
+ x, residual = x
547
+ if residual is not None:
548
+ x = x + residual
549
+
550
+ x = self.clip_projector(x)
551
+
552
+ x = self.fc_norm(x)
553
+ return x
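
A minimal end-to-end sketch of the vision tower above, run on CPU with deliberately small illustrative sizes (the real values come from config.json under model.vision_encoder). Note that importing this module still requires flash_attn to be installed, because FusedMLP is imported unconditionally at the top of the file:

import torch

encoder = InternVideo2(
    img_size=224, patch_size=14, embed_dim=64, num_heads=4, depth=2,
    attn_pool_num_heads=4, clip_embed_dim=32, num_frames=8, tubelet_size=1,
    use_flash_attn=False, use_fused_rmsnorm=False, use_fused_mlp=False,
)
video = torch.randn(1, 3, 8, 224, 224)   # [B, C, T, H, W]
feat = encoder(video)                     # [B, clip_embed_dim] == [1, 32]
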
mobile_clip.py ADDED
@@ -0,0 +1,264 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ import math
6
+ from typing import Optional, Sequence
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+
11
+ from typing import Dict
12
+ import open_clip
13
+
14
+ from .mobile_clip_transformer import (
15
+ PositionalEmbedding,
16
+ TransformerEncoder,
17
+ get_normalization_layer,
18
+ )
19
+
20
+
21
+ class TextTransformer(nn.Module):
22
+ def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
23
+ super().__init__()
24
+
25
+ model_dim = cfg["dim"]
26
+ no_scale_embedding = cfg.get("no_scale_embedding", False)
27
+ no_pos_embedding = cfg.get("no_pos_embedding", False)
28
+ embed_dropout = cfg.get("embed_dropout", 0.0)
29
+ norm_layer = cfg["norm_layer"]
30
+ variant = cfg["model_name"]
31
+ self.vocab_size = cfg["vocab_size"]
32
+ self.projection_dim = projection_dim
33
+
34
+ # Token embedding layer
35
+ self.embedding_layer = nn.Embedding(
36
+ embedding_dim=model_dim, num_embeddings=self.vocab_size
37
+ )
38
+ self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5
39
+
40
+ # Context length
41
+ context_length = cfg["context_length"]
42
+ assert (
43
+ context_length is not None
44
+ ), "Context length can't be None. Please set value accordingly."
45
+
46
+ self.positional_embedding = (
47
+ None
48
+ if no_pos_embedding
49
+ else PositionalEmbedding(
50
+ num_embeddings=context_length, embedding_dim=model_dim
51
+ )
52
+ )
53
+
54
+ self.embedding_dropout = nn.Dropout(p=embed_dropout)
55
+
56
+ # Transformer layer
57
+ n_transformer_layers = cfg["n_transformer_layers"]
58
+
59
+ # FFN multipliers for transformer layer
60
+ ffn_multipliers = cfg["ffn_multiplier_per_layer"]
61
+ if isinstance(ffn_multipliers, (float, int)):
62
+ ffn_multipliers = [ffn_multipliers] * n_transformer_layers
63
+
64
+ if not isinstance(ffn_multipliers, Sequence):
65
+ Warning(
66
+ "{} expects FFN multipliers as a list, whose length is the same as"
67
+ " number of transformer layers. Got: {}".format(
68
+ self.__class__.__name__, type(ffn_multipliers)
69
+ )
70
+ )
71
+ elif (
72
+ isinstance(ffn_multipliers, Sequence)
73
+ and len(ffn_multipliers) != n_transformer_layers
74
+ ):
75
+ Warning(
76
+ "We need FFN multiplier for each transformer layer. Got {} ffn"
77
+ " multipliers while number of transformer layers = {}".format(
78
+ len(ffn_multipliers), n_transformer_layers
79
+ )
80
+ )
81
+ ffn_dims = [
82
+ int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
83
+ for ffn_mult in ffn_multipliers
84
+ ]
85
+
86
+ # Heads for transformer layers
87
+ mha_heads = cfg["n_heads_per_layer"]
88
+ if isinstance(mha_heads, int):
89
+ mha_heads = [mha_heads] * n_transformer_layers
90
+
91
+ if not isinstance(mha_heads, Sequence):
92
+ Warning(
93
+ "{} expects MHA heads as a list, whose length is the same as number of "
94
+ "transformer layers. Got: {}".format(
95
+ self.__class__.__name__, type(mha_heads)
96
+ )
97
+ )
98
+ elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers:
99
+ Warning(
100
+ "{} needs MHA heads for each transformer layer. Got {} mha heads while"
101
+ " number of transformer layers = {}".format(
102
+ self.__class__.__name__, len(mha_heads), n_transformer_layers
103
+ )
104
+ )
105
+
106
+ if variant == "base":
107
+ self.transformer = nn.ModuleList(
108
+ [
109
+ TransformerEncoder(
110
+ embed_dim=model_dim,
111
+ num_heads=mha_heads[layer_idx],
112
+ ffn_latent_dim=ffn_dims[layer_idx],
113
+ transformer_norm_layer=norm_layer,
114
+ )
115
+ for layer_idx in range(n_transformer_layers)
116
+ ]
117
+ )
118
+ elif variant == "mct":
119
+ raise NotImplementedError
120
+ else:
121
+ raise ValueError("Unrecognized text encoder variant {}".format(variant))
122
+
123
+ self.final_layer_norm = get_normalization_layer(
124
+ num_features=model_dim, norm_type=norm_layer
125
+ )
126
+
127
+ self.projection_layer = nn.Parameter(
128
+ torch.empty(model_dim, self.projection_dim)
129
+ )
130
+ self.model_dim = model_dim
131
+ self.causal_masking = cfg["causal_masking"]
132
+
133
+ def forward_embedding(self, text_tokens: Tensor) -> Tensor:
134
+ """Return text embedding for all tokens.
135
+
136
+ Args:
137
+ text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
138
+
139
+ Returns:
140
+ A tensor of [batch_size, context_length, hidden_dim].
141
+ """
142
+ # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
143
+ token_emb = self.embedding_layer(text_tokens)
144
+ seq_len = token_emb.shape[1]
145
+ if self.positional_embedding is not None:
146
+ token_emb = token_emb + self.positional_embedding(seq_len).to(
147
+ token_emb.dtype
148
+ )
149
+ token_emb = self.embedding_dropout(token_emb)
150
+ return token_emb
151
+
152
+ def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor:
153
+ """Build causal attention mask [batch_size, context_length, context_length]."""
154
+ # Build mask with full attention between the tokens
155
+ # pytorch uses additive attention mask; fill with -inf
156
+ mask = torch.empty(context_length, context_length)
157
+ mask.fill_(float("-inf"))
158
+ mask.triu_(1) # zero out the lower diagonal
159
+ mask = mask.unsqueeze(0) # add dummy batch dimension
160
+ mask = mask.expand(batch_size, -1, -1)
161
+ return mask
162
+
163
+ def encode_text(
164
+ self,
165
+ text_tokens: Tensor,
166
+ key_padding_mask: Optional[Tensor] = None,
167
+ return_all_tokens: bool = False,
168
+ *args,
169
+ **kwargs
170
+ ) -> Tensor:
171
+ """Return text token embeddings.
172
+
173
+ Args:
174
+ text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
175
+ key_padding_mask: a tensor of boolean values as the padding mask.
176
+ Shape: [batch_size, context_length]
177
+ return_all_tokens: a boolean flag to return all tokens, defaults to False
178
+ to return only EOT token embedding.
179
+ Returns:
180
+ A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
181
+ True, otherwise a tensor of [batch_size, hidden_dim].
182
+ """
183
+ # Discrete tokens to continuous embeddings
184
+ # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
185
+ token_emb = self.forward_embedding(text_tokens)
186
+
187
+ # [1, context_length, context_length]
188
+ attn_mask = None
189
+ if self.causal_masking:
190
+ attn_mask = self.build_attention_mask(
191
+ context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
192
+ )
193
+ attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
194
+ key_padding_mask = None
195
+
196
+ for layer in self.transformer:
197
+ token_emb = layer(
198
+ token_emb,
199
+ key_padding_mask=key_padding_mask,
200
+ attn_mask=attn_mask,
201
+ )
202
+
203
+ # Apply layer norm
204
+ token_emb = self.final_layer_norm(token_emb)
205
+
206
+ if return_all_tokens:
207
+ return token_emb
208
+
209
+ # Take features from the eot embedding (eot_token is the highest number in each sequence)
210
+ token_emb = token_emb[
211
+ torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)
212
+ ]
213
+
214
+ token_emb = token_emb @ self.projection_layer
215
+ return token_emb
216
+
217
+ def forward(
218
+ self,
219
+ text_tokens: Tensor,
220
+ key_padding_mask: Optional[Tensor] = None,
221
+ return_all_tokens: bool = False,
222
+ *args,
223
+ **kwargs
224
+ ) -> Tensor:
225
+ # Image-text pair data with single caption
226
+ # [B, CL] --> [B, d]
227
+ text_tokens = self.encode_text(
228
+ text_tokens=text_tokens,
229
+ key_padding_mask=key_padding_mask,
230
+ return_all_tokens=return_all_tokens,
231
+ *args,
232
+ **kwargs
233
+ )
234
+ return text_tokens
235
+
236
+
237
+ class ClipTokenizer(nn.Module):
238
+ def __init__(self, cfg, *args, **kwargs):
239
+ super().__init__()
240
+ self.context_length = cfg["text_cfg"]["context_length"]
241
+ model_name = cfg["text_cfg"].get("open_clip_tokenizer", "ViT-B-16")
242
+ self.tokenizer = open_clip.get_tokenizer(model_name)
243
+
244
+ def get_vocab_size(self) -> int:
245
+ return len(self.tokenizer.encoder)
246
+
247
+ def get_encodings(self) -> Dict[str, int]:
248
+ return self.tokenizer.encoder
249
+
250
+ def get_eot_token(self) -> int:
251
+ # Tokenizing an empty string returns a list [sot_id, eot_id]
252
+ return self.tokenizer("")[1]
253
+
254
+ def get_sot_token(self) -> int:
255
+ # Tokenizing an empty string returns a list [sot_id, eot_id]
256
+ return self.tokenizer("")[0]
257
+
258
+ def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
259
+ # the tokenizer returns token indices as a LongTensor of shape [num_texts, context_length]
260
+ tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
261
+ assert (
262
+ tokenized_sentence.shape[-1] == self.context_length
263
+ ), "Tokenized tensor should be exactly `context_length` long."
264
+ return tokenized_sentence
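
A minimal sketch of how the two classes above fit together, with illustrative config values (the real ones live under model.text_encoder in config.json) and open_clip installed for the tokenizer. Note that projection_layer is created with torch.empty, so the numbers are only meaningful once real weights are loaded:

import torch

text_cfg = {
    "dim": 256, "vocab_size": 49408, "context_length": 77,
    "n_transformer_layers": 2, "ffn_multiplier_per_layer": 4.0,
    "n_heads_per_layer": 4, "norm_layer": "layer_norm",
    "model_name": "base", "causal_masking": True,
}
tokenizer = ClipTokenizer({"text_cfg": text_cfg})
text_encoder = TextTransformer(text_cfg, projection_dim=128)

tokens = tokenizer("a person riding a bike")   # [1, 77] token ids
with torch.no_grad():
    text_feat = text_encoder(tokens)           # [1, 128] projected EOT embedding
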
mobile_clip_transformer.py ADDED
@@ -0,0 +1,449 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ """
6
+ Implementation of the following modules is borrowed from ml-cvnets repo:
7
+ https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py
8
+ https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py
9
+
10
+ Please see ACKNOWLEDGEMENTS for license details.
11
+ """
12
+
13
+ from typing import List, Optional, Union
14
+
15
+ import torch
16
+ from torch import Size, Tensor, nn
17
+ from torch.nn import functional as F
18
+ from torchvision.ops import StochasticDepth
19
+
20
+
21
+ class LayerNormFP32(nn.LayerNorm):
22
+ """
23
+ Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over a input tensor with FP32 precision
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ normalized_shape: Union[int, List[int], Size],
29
+ eps: Optional[float] = 1e-5,
30
+ elementwise_affine: Optional[bool] = True,
31
+ *args,
32
+ **kwargs,
33
+ ):
34
+ super().__init__(
35
+ normalized_shape=normalized_shape,
36
+ eps=eps,
37
+ elementwise_affine=elementwise_affine,
38
+ *args,
39
+ **kwargs,
40
+ )
41
+
42
+ def forward(self, x: Tensor) -> Tensor:
43
+ # Convert input from dtype X to FP32 and perform normalization operation.
44
+ # This may help with underflow/overflow issues that we typically see with normalization layers
45
+ inp_dtype = x.dtype
46
+ return super().forward(x.to(torch.float32)).to(inp_dtype)
47
+
48
+
49
+ def get_normalization_layer(norm_type, num_features):
50
+ if norm_type == "layer_norm":
51
+ return nn.LayerNorm(num_features)
52
+ elif norm_type == "layer_norm_fp32":
53
+ return LayerNormFP32(num_features)
54
+ else:
55
+ raise NotImplementedError(f"Option: {norm_type} not supported.")
56
+
57
+
58
+ class PositionalEmbedding(nn.Module):
59
+ def __init__(
60
+ self,
61
+ num_embeddings: int,
62
+ embedding_dim: int,
63
+ padding_idx: Optional[int] = None,
64
+ is_learnable: Optional[bool] = False,
65
+ interpolation_mode: Optional[str] = "bilinear",
66
+ *args,
67
+ **kwargs,
68
+ ):
69
+ super().__init__()
70
+ # Add other pos embedding here and logic to choose between them
71
+ module = LearnablePositionalEmbedding
72
+
73
+ self.pos_embed = module(
74
+ num_embeddings=num_embeddings,
75
+ embedding_dim=embedding_dim,
76
+ padding_idx=padding_idx,
77
+ interpolation_mode=interpolation_mode,
78
+ *args,
79
+ **kwargs,
80
+ )
81
+
82
+ def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
83
+ return self.pos_embed(seq_len, *args, **kwargs)
84
+
85
+ def __repr__(self):
86
+ return self.pos_embed.__repr__()
87
+
88
+
89
+ class LearnablePositionalEmbedding(nn.Module):
90
+ """Learnable Positional embedding"""
91
+
92
+ def __init__(
93
+ self,
94
+ num_embeddings: int,
95
+ embedding_dim: int,
96
+ padding_idx: Optional[int] = None,
97
+ interpolation_mode: Optional[str] = "bilinear",
98
+ *args,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
103
+ self.embedding_dim = embedding_dim
104
+ self.num_embeddings = num_embeddings
105
+ self.padding_idx = padding_idx
106
+ self.interpolation_mode = interpolation_mode
107
+
108
+ self.reset_parameters()
109
+
110
+ def reset_parameters(self) -> None:
111
+ nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
112
+ if self.padding_idx is not None:
113
+ with torch.no_grad():
114
+ self.pos_embed[:, :, self.padding_idx, ...] = 0.0
115
+
116
+ def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
117
+ # scale pos embedding
118
+ pos_embed = self.pos_embed
119
+ if self.padding_idx is not None:
120
+ with torch.no_grad():
121
+ pos_embed[:, :, self.padding_idx, ...] = 0.0
122
+
123
+ if seq_len != self.num_embeddings:
124
+ pos_embed = F.interpolate(
125
+ pos_embed,
126
+ size=(seq_len, self.embedding_dim),
127
+ mode=self.interpolation_mode,
128
+ )
129
+
130
+ # Input is of the form [Batch, Seq_len, Embedding_dim]
131
+ return pos_embed.reshape(1, seq_len, self.embedding_dim)
132
+
133
+ def __repr__(self):
134
+ return "{}(num_embeddings={}, embedding_dim={}, padding_idx={})".format(
135
+ self.__class__.__name__,
136
+ self.num_embeddings,
137
+ self.embedding_dim,
138
+ self.padding_idx,
139
+ )
140
+
141
+
142
+ class MultiHeadAttention(nn.Module):
143
+ """
144
+ This layer applies a multi-head self- or cross-attention as described in
145
+ `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
146
+
147
+ Args:
148
+ embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
149
+ num_heads (int): Number of heads in multi-head attention
150
+ attn_dropout (Optional[float]): Attention dropout. Default: 0.0
151
+ bias (Optional[bool]): Use bias or not. Default: ``True``
152
+
153
+ Shape:
154
+ - Input:
155
+ - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
156
+ and :math:`C_{in}` is input embedding dim
157
+ - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
158
+ - Output: same shape as the input
159
+
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ embed_dim: int,
165
+ num_heads: int,
166
+ attn_dropout: Optional[float] = 0.0,
167
+ bias: Optional[bool] = True,
168
+ output_dim: Optional[int] = None,
169
+ *args,
170
+ **kwargs,
171
+ ) -> None:
172
+ if output_dim is None:
173
+ output_dim = embed_dim
174
+ super().__init__()
175
+ if embed_dim % num_heads != 0:
176
+ Warning(
177
+ "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
178
+ self.__class__.__name__, embed_dim, num_heads
179
+ )
180
+ )
181
+
182
+ self.qkv_proj = nn.Linear(
183
+ in_features=embed_dim, out_features=3 * embed_dim, bias=bias
184
+ )
185
+
186
+ self.attn_dropout = nn.Dropout(p=attn_dropout)
187
+ self.out_proj = nn.Linear(
188
+ in_features=embed_dim, out_features=output_dim, bias=bias
189
+ )
190
+
191
+ self.head_dim = embed_dim // num_heads
192
+ self.scaling = self.head_dim**-0.5
193
+ self.softmax = nn.Softmax(dim=-1)
194
+ self.num_heads = num_heads
195
+ self.embed_dim = embed_dim
196
+ self.use_separate_proj_weight = embed_dim != output_dim
197
+
198
+ def __repr__(self):
199
+ return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
200
+ self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
201
+ )
202
+
203
+ def _forward_impl(
204
+ self,
205
+ x_q: Tensor,
206
+ x_kv: Optional[Tensor] = None,
207
+ key_padding_mask: Optional[Tensor] = None,
208
+ attn_mask: Optional[Tensor] = None,
209
+ ) -> Tensor:
210
+ # [N, S, C]
211
+ b_sz, S_len, in_channels = x_q.shape
212
+
213
+ if x_kv is None:
214
+ # self-attention
215
+ # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
216
+ qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
217
+ # [N, S, 3, h, c] --> [N, h, 3, S, C]
218
+ qkv = qkv.transpose(1, 3).contiguous()
219
+
220
+ # [N, h, 3, S, C] --> [N, h, S, C] x 3
221
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
222
+ else:
223
+ T_len = x_kv.shape[1]
224
+
225
+ # cross-attention
226
+ # [N, S, C]
227
+ query = F.linear(
228
+ x_q,
229
+ weight=self.qkv_proj.weight[: self.embed_dim, ...],
230
+ bias=self.qkv_proj.bias[: self.embed_dim]
231
+ if self.qkv_proj.bias is not None
232
+ else None,
233
+ )
234
+ # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
235
+ query = (
236
+ query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
237
+ .transpose(1, 2)
238
+ .contiguous()
239
+ )
240
+
241
+ # [N, T, C] --> [N, T, 2C]
242
+ kv = F.linear(
243
+ x_kv,
244
+ weight=self.qkv_proj.weight[self.embed_dim :, ...],
245
+ bias=self.qkv_proj.bias[self.embed_dim :]
246
+ if self.qkv_proj.bias is not None
247
+ else None,
248
+ )
249
+ # [N, T, 2C] --> [N, T, 2, h, c]
250
+ kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
251
+ # [N, T, 2, h, c] --> [N, h, 2, T, c]
252
+ kv = kv.transpose(1, 3).contiguous()
253
+ key, value = kv[:, :, 0], kv[:, :, 1]
254
+
255
+ query = query * self.scaling
256
+
257
+ # [N h, T, c] --> [N, h, c, T]
258
+ key = key.transpose(-1, -2)
259
+
260
+ # QK^T
261
+ # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
262
+ attn = torch.matmul(query, key)
263
+
264
+ batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape
265
+ if attn_mask is not None:
266
+ # attn_mask shape should be the same as attn
267
+ assert list(attn_mask.shape) == [
268
+ batch_size,
269
+ num_src_tokens,
270
+ num_tgt_tokens,
271
+ ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
272
+ batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
273
+ )
274
+ # [N, S, T] --> [N, 1, S, T]
275
+ attn_mask = attn_mask.unsqueeze(1)
276
+ attn = attn + attn_mask
277
+
278
+ if key_padding_mask is not None:
279
+ # Do not attend to padding positions
280
+ # key padding mask size is [N, T]
281
+ assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
282
+ batch_size,
283
+ num_tgt_tokens,
284
+ ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format(
285
+ batch_size, num_tgt_tokens, key_padding_mask.shape
286
+ )
287
+ attn = attn.masked_fill(
288
+ key_padding_mask.unsqueeze(1)
289
+ .unsqueeze(2)
290
+ .to(torch.bool), # [N, T] --> [N, 1, 1, T]
291
+ float("-inf"),
292
+ )
293
+
294
+ attn_dtype = attn.dtype
295
+ attn_as_float = self.softmax(attn.float())
296
+ attn = attn_as_float.to(attn_dtype)
297
+ attn = self.attn_dropout(attn)
298
+
299
+ # weighted sum
300
+ # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
301
+ out = torch.matmul(attn, value)
302
+
303
+ # [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
304
+ out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
305
+ out = self.out_proj(out)
306
+
307
+ return out
308
+
309
+ def forward(
310
+ self,
311
+ x_q: Tensor,
312
+ x_kv: Optional[Tensor] = None,
313
+ key_padding_mask: Optional[Tensor] = None,
314
+ attn_mask: Optional[Tensor] = None,
315
+ *args,
316
+ **kwargs,
317
+ ) -> Tensor:
318
+ # [Batch , Sequence, Hidden_dim]
319
+ return self._forward_impl(
320
+ x_q=x_q,
321
+ x_kv=x_kv,
322
+ key_padding_mask=key_padding_mask,
323
+ attn_mask=attn_mask,
324
+ )
325
+
326
+
327
+ class TransformerEncoder(nn.Module):
328
+ """
329
+ This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
330
+ Args:
331
+ embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
332
+ ffn_latent_dim: Inner dimension of the FFN.
333
+ num_heads: Number of heads in multi-head attention. Default: 8.
334
+ attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
335
+ dropout: Dropout rate. Default: 0.0.
336
+ ffn_dropout: Dropout between FFN layers. Default: 0.0.
337
+ transformer_norm_layer: Normalization layer. Default: layer_norm.
338
+ stochastic_dropout: Stochastic dropout setting. Default: 0.0.
339
+
340
+ Shape:
341
+ - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
342
+ and :math:`C_{in}` is input embedding dim
343
+ - Output: same shape as the input
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ embed_dim: int,
349
+ ffn_latent_dim: int,
350
+ num_heads: Optional[int] = 8,
351
+ attn_dropout: Optional[float] = 0.0,
352
+ dropout: Optional[float] = 0.0,
353
+ ffn_dropout: Optional[float] = 0.0,
354
+ transformer_norm_layer: Optional[str] = "layer_norm",
355
+ stochastic_dropout: Optional[float] = 0.0,
356
+ *args,
357
+ **kwargs,
358
+ ) -> None:
359
+
360
+ super().__init__()
361
+
362
+ # Build attention layer
363
+ attn_unit = MultiHeadAttention(
364
+ embed_dim,
365
+ num_heads,
366
+ attn_dropout=attn_dropout,
367
+ bias=True,
368
+ )
369
+
370
+ self.pre_norm_mha = nn.Sequential(
371
+ get_normalization_layer(
372
+ norm_type=transformer_norm_layer, num_features=embed_dim
373
+ ),
374
+ attn_unit,
375
+ nn.Dropout(p=dropout),
376
+ )
377
+
378
+ act_name = nn.GELU()
379
+ self.pre_norm_ffn = nn.Sequential(
380
+ get_normalization_layer(
381
+ norm_type=transformer_norm_layer, num_features=embed_dim
382
+ ),
383
+ nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
384
+ act_name,
385
+ nn.Dropout(p=ffn_dropout),
386
+ nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
387
+ nn.Dropout(p=dropout),
388
+ )
389
+
390
+ self.drop_path = nn.Identity()
391
+ if stochastic_dropout > 0.0:
392
+ if dropout > 0.0:
393
+ Warning(
394
+ "Stochastic dropout and dropout are mutually exclusive. "
395
+ "Use either of them, but not both."
396
+ "Got: {} and {}".format(stochastic_dropout, dropout)
397
+ )
398
+ self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")
399
+
400
+ self.embed_dim = embed_dim
401
+ self.ffn_dim = ffn_latent_dim
402
+ self.ffn_dropout = ffn_dropout
403
+ self.stochastic_dropout = stochastic_dropout
404
+ self.std_dropout = dropout
405
+ self.attn_fn_name = attn_unit.__class__.__name__
406
+ self.act_fn_name = act_name.__class__.__name__
407
+ self.norm_type = transformer_norm_layer
408
+
409
+ def __repr__(self) -> str:
410
+ return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format(
411
+ self.__class__.__name__,
412
+ self.embed_dim,
413
+ self.ffn_dim,
414
+ self.std_dropout,
415
+ self.ffn_dropout,
416
+ self.stochastic_dropout,
417
+ self.attn_fn_name,
418
+ self.act_fn_name,
419
+ self.norm_type,
420
+ )
421
+
422
+ def forward(
423
+ self,
424
+ x: Tensor,
425
+ x_prev: Optional[Tensor] = None,
426
+ key_padding_mask: Optional[Tensor] = None,
427
+ attn_mask: Optional[Tensor] = None,
428
+ *args,
429
+ **kwargs,
430
+ ) -> Tensor:
431
+
432
+ # Multi-head attention
433
+ res = x
434
+ x = self.pre_norm_mha[0](x) # norm
435
+ x = self.pre_norm_mha[1](
436
+ x_q=x,
437
+ x_kv=x_prev,
438
+ key_padding_mask=key_padding_mask,
439
+ attn_mask=attn_mask,
440
+ *args,
441
+ **kwargs,
442
+ ) # mha
443
+
444
+ x = self.drop_path(self.pre_norm_mha[2](x)) # applying stochastic depth
445
+ x = x + res
446
+
447
+ # Feed forward network
448
+ x = x + self.drop_path(self.pre_norm_ffn(x))
449
+ return x
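
A minimal sketch of a single pre-norm TransformerEncoder layer defined above, fed an additive causal mask in the [batch, src_tokens, tgt_tokens] shape that MultiHeadAttention checks for (sizes are illustrative):

import torch

layer = TransformerEncoder(embed_dim=64, ffn_latent_dim=256, num_heads=4)
x = torch.randn(2, 10, 64)                             # [batch, tokens, embed_dim]

causal = torch.full((10, 10), float("-inf")).triu_(1)  # -inf above the diagonal, 0 elsewhere
causal = causal.unsqueeze(0).expand(2, -1, -1)         # [batch, 10, 10]

y = layer(x, attn_mask=causal)                         # same shape as x: [2, 10, 64]
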
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:136b42a078a8ec440e38b56d91d570fb0969643a641795e06e171162ab176b4e
3
+ size 745562274
modeling_internvideo2encoder.py ADDED
@@ -0,0 +1,152 @@
1
+ # from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
2
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
3
+ from .config import InternVideo2Config as config
4
+ import warnings
5
+ import torch
6
+ from torch import nn
7
+ import torchvision.transforms as transforms
8
+ from torchvision.transforms import InterpolationMode
9
+ from transformers.utils import logging
10
+ warnings.filterwarnings("ignore")
11
+ from .internvideo2_clip_vision import InternVideo2
12
+ from .mobile_clip import TextTransformer, ClipTokenizer
13
+ logger = logging.get_logger(__name__)
14
+
15
+ class InternVideo2_CLIP_small(PreTrainedModel):
16
+ config_class = config
17
+
18
+ def __init__(self, config, tokenizer=None, is_pretrain=True):
19
+ super().__init__(config)
20
+ self.config = config
21
+ self.tokenizer = tokenizer
22
+ self.is_pretrain = is_pretrain
23
+ logger.info(config)
24
+ if tokenizer is None:
25
+ self.tokenizer = ClipTokenizer(self.config.model.text_encoder)
26
+ # self.model = IV2S2(self.config).to('cpu').to(torch.float16)
27
+ self.vision_encoder = self.build_vision_encoder()
28
+
29
+ self.vision_align = nn.Sequential(
30
+ nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
31
+ nn.Linear(
32
+ self.config.model.vision_encoder.clip_embed_dim,
33
+ self.config.model.vision_encoder.align_dim
34
+ ),
35
+ )
36
+ self.text_encoder = self.build_text_encoder(cfg=self.config.model.text_encoder['text_cfg'], projection_dim=self.config.model.text_encoder["embed_dim"])
37
+ # adopt 1 / 100. as in ViCLIP
38
+ self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
39
+ self.temp_min = config.model.temp_min
40
+
41
+ if self.config.model.freeze_vision:
42
+ for name, p in self.vision_encoder.named_parameters():
43
+ if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
44
+ logger.info(f"Unfreeze {name}")
45
+ else:
46
+ logger.info(f"Freeze {name}")
47
+ p.requires_grad = False
48
+ if self.config.model.freeze_text:
49
+ for name, p in self.text_encoder.named_parameters():
50
+ if self.config.model.open_text_projection and name.startswith('projection_layer'):
51
+ logger.info(f"Unfreeze {name}")
52
+ else:
53
+ logger.info(f"Freeze {name}")
54
+ p.requires_grad = False
55
+ img_size = self.config.model.vision_encoder.img_size
56
+ self.transform = transforms.Compose(
57
+ [
58
+ transforms.Resize(
59
+ (img_size, img_size),
60
+ interpolation=InterpolationMode.BICUBIC,
61
+ ),
62
+ transforms.Lambda(lambda x: x.float().div(255.0)),
63
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
64
+ ]
65
+ )
66
+
67
+
68
+ @torch.no_grad()
69
+ def clip_contrastive_temperature(self):
70
+ """Seems only used during pre-training"""
71
+ self.temp.clamp_(min=self.temp_min)
72
+
73
+ def encode_vision(self, image, test=False):
74
+ """encode image / videos as features.
75
+
76
+ Args:
77
+ image (torch.Tensor): The input images.
78
+ test (bool): Whether testing.
79
+
80
+ Returns:
+ - vision_embeds (torch.Tensor): The pooled visual feature after vision_align. Shape: [B,C].
82
+
83
+ """
84
+ T = image.shape[1]
85
+ use_image = True if T == 1 else False
86
+ image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]
87
+
88
+ vision_embeds = self.vision_encoder(image, use_image=use_image)
89
+ vision_embeds = self.vision_align(vision_embeds)
90
+ return vision_embeds
91
+
92
+ def encode_text(self, text):
93
+ """encode text.
94
+ Args:
95
+ text (torch.Tensor): Token ids produced by `ClipTokenizer`. Shape: [B,L].
+ Returns:
+ - text_embeds (torch.Tensor): The projected text feature taken at the EOT token. Shape: [B,C].
101
+
102
+ """
103
+ text_embeds = self.text_encoder(text)
104
+ return text_embeds
105
+
106
+ def build_vision_encoder(self):
107
+ """build vision encoder
108
+ Returns: vision_encoder (nn.Module).
109
+
110
+ """
111
+ vision_encoder = InternVideo2(
112
+ in_chans=self.config.model.vision_encoder.in_chans,
113
+ patch_size=self.config.model.vision_encoder.patch_size,
114
+ img_size=self.config.model.vision_encoder.img_size,
115
+ qkv_bias=self.config.model.vision_encoder.qkv_bias,
116
+ drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
117
+ head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
118
+ embed_dim=self.config.model.vision_encoder.embed_dim,
119
+ num_heads=self.config.model.vision_encoder.num_heads,
120
+ mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
121
+ init_values=self.config.model.vision_encoder.init_values,
122
+ qk_normalization=self.config.model.vision_encoder.qk_normalization,
123
+ depth=self.config.model.vision_encoder.depth,
124
+ use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
125
+ use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
126
+ use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
127
+ fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
128
+ attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
129
+ clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
130
+ layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
131
+ num_frames=self.config.model.vision_encoder.num_frames,
132
+ tubelet_size=self.config.model.vision_encoder.tubelet_size,
133
+ sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
134
+ use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
135
+ checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
136
+ )
137
+ return vision_encoder
138
+
139
+ def build_text_encoder(self, cfg, projection_dim):
140
+ """build text_encoder and possiblly video-to-text multimodal fusion encoder.
141
+ Returns: nn.Module. The text encoder
142
+
143
+ """
144
+ text_encoder = TextTransformer(cfg, projection_dim)
145
+
146
+ return text_encoder
147
+
148
+ if __name__ == "__main__":
149
+ model_config = config()
150
+ model = InternVideo2_CLIP_small(model_config).to(model_config.device)
151
+ x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)  # [B, T, C, H, W]
152
+ output = model.encode_vision(x)  # the wrapper exposes encode_vision / encode_text rather than forward
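
A sketch of typical downstream use, assuming the repository is published on the Hub and loaded through the auto_map in config.json; "<repo_id>" is a placeholder, and it assumes the vision align_dim matches the text projection dim and that the config's flash/fused kernels are available on the machine:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<repo_id>", trust_remote_code=True).eval()

raw = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.uint8)      # 8 dummy frames
frames = torch.stack([model.transform(f) for f in raw]).unsqueeze(0)  # [1, 8, 3, 224, 224]

with torch.no_grad():
    v = model.encode_vision(frames)                          # [1, align_dim]
    t = model.encode_text(model.tokenizer("a dog running"))  # [1, embed_dim]
    score = torch.nn.functional.cosine_similarity(v, t)
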
pos_embed.py ADDED
@@ -0,0 +1,299 @@
1
+ import numpy as np
2
+ import torch
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ # --------------------------------------------------------
8
+ # 3D sine-cosine position embedding
9
+ # References:
10
+ # MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
11
+ # --------------------------------------------------------
12
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
13
+ """
14
+ grid_size: int of the grid height and width
15
+ t_size: int of the temporal size
16
+ return:
17
+ pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
18
+ """
19
+ assert embed_dim % 4 == 0
20
+ embed_dim_spatial = embed_dim // 4 * 3
21
+ embed_dim_temporal = embed_dim // 4
22
+
23
+ # spatial
24
+ grid_h = np.arange(grid_size, dtype=np.float32)
25
+ grid_w = np.arange(grid_size, dtype=np.float32)
26
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
27
+ grid = np.stack(grid, axis=0)
28
+
29
+ grid = grid.reshape([2, 1, grid_size, grid_size])
30
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
31
+ embed_dim_spatial, grid
32
+ )
33
+
34
+ # temporal
35
+ grid_t = np.arange(t_size, dtype=np.float32)
36
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
37
+ embed_dim_temporal, grid_t
38
+ )
39
+
40
+ # concate: [T, H, W] order
41
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
42
+ pos_embed_temporal = np.repeat(
43
+ pos_embed_temporal, grid_size**2, axis=1
44
+ ) # [T, H*W, D // 4]
45
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
46
+ pos_embed_spatial = np.repeat(
47
+ pos_embed_spatial, t_size, axis=0
48
+ ) # [T, H*W, D // 4 * 3]
49
+
50
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
51
+ pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
52
+
53
+ if cls_token:
54
+ pos_embed = np.concatenate(
55
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
56
+ )
57
+ return pos_embed
58
+
59
+
60
+ # --------------------------------------------------------
61
+ # 2D sine-cosine position embedding
62
+ # References:
63
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
64
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
65
+ # --------------------------------------------------------
66
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
67
+ """
68
+ grid_size: int of the grid height and width
69
+ return:
70
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71
+ """
72
+ grid_h = np.arange(grid_size, dtype=np.float32)
73
+ grid_w = np.arange(grid_size, dtype=np.float32)
74
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
75
+ grid = np.stack(grid, axis=0)
76
+
77
+ grid = grid.reshape([2, 1, grid_size, grid_size])
78
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79
+ if cls_token:
80
+ pos_embed = np.concatenate(
81
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
82
+ )
83
+ return pos_embed
84
+
85
+
86
+ def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
87
+ """
88
+ t_size: int of the temporal size
89
+ return:
90
+ pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
91
+ """
92
+ grid_t = np.arange(t_size, dtype=np.float32)
93
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
94
+ if cls_token:
95
+ pos_embed = np.concatenate(
96
+ [np.zeros([1, embed_dim]), pos_embed], axis=0
97
+ )
98
+ return pos_embed
99
+
100
+
101
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
102
+ assert embed_dim % 2 == 0
103
+
104
+ # use half of dimensions to encode grid_h
105
+ emb_h = get_1d_sincos_pos_embed_from_grid(
106
+ embed_dim // 2, grid[0]
107
+ ) # (H*W, D/2)
108
+ emb_w = get_1d_sincos_pos_embed_from_grid(
109
+ embed_dim // 2, grid[1]
110
+ ) # (H*W, D/2)
111
+
112
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
113
+ return emb
114
+
115
+
116
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
117
+ """
118
+ embed_dim: output dimension for each position
119
+ pos: a list of positions to be encoded: size (M,)
120
+ out: (M, D)
121
+ """
122
+ assert embed_dim % 2 == 0
123
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
124
+ omega /= embed_dim / 2.0
125
+ omega = 1.0 / 10000**omega # (D/2,)
126
+
127
+ pos = pos.reshape(-1) # (M,)
128
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
129
+
130
+ emb_sin = np.sin(out) # (M, D/2)
131
+ emb_cos = np.cos(out) # (M, D/2)
132
+
133
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
134
+ return emb
135
+
136
+
137
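The helpers above are plain NumPy and easy to sanity-check in isolation: the 2D variant splits the channel dimension evenly between the height and width coordinates, and the factorized 3D variant at the top of this file gives one quarter of the channels to time and three quarters to space before concatenating per token (see the `D // 4` / `D // 4 * 3` shape comments). A minimal shape check, assuming the functions above are in scope (for example pasted into or imported from this module; the sizes are arbitrary example values):

```python
# Minimal sanity check for the sin-cos builders defined above.
# Assumes get_2d_sincos_pos_embed / get_1d_sincos_pos_embed are in scope.
embed_dim, grid_size, t_size = 64, 4, 2

spatial = get_2d_sincos_pos_embed(embed_dim, grid_size)   # (grid_size**2, embed_dim)
temporal = get_1d_sincos_pos_embed(embed_dim, t_size)     # (t_size, embed_dim)

assert spatial.shape == (grid_size ** 2, embed_dim)
assert temporal.shape == (t_size, embed_dim)

# With cls_token=True a zero row is prepended for the [CLS] position.
with_cls = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
assert with_cls.shape == (1 + grid_size ** 2, embed_dim)
```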
+ def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
+     if pos_name in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model[pos_name]
+         embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
+ 
+         # we use 4 frames for pretraining
+         new_t_size = model.T
+         # height (== width) for the checkpoint position embedding
+         orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // orig_t_size) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int((num_patches // new_t_size) ** 0.5)
+ 
+         # class_token and dist_token are kept unchanged
+         if orig_t_size != new_t_size:
+             logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+             pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+             pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+             pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+             pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model[pos_name] = new_pos_embed
+             pos_embed_checkpoint = new_pos_embed
+ 
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             # B, L, C -> BT, H, W, C -> BT, C, H, W
+             pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+             pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model[pos_name] = new_pos_embed
+ 
+ 
+ def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size=8):
+     # interpolate position embedding
+     for pos_name in ['pos_embed', 'clip_pos_embed']:
+         if pos_name in checkpoint_model:
+             pos_embed_checkpoint = checkpoint_model[pos_name]
+             embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
+             num_patches = model.patch_embed.num_patches
+             num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
+ 
+             # we use 8 frames for pretraining
+             # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
+             new_t_size = model.num_frames // model.tubelet_size
+             # height (== width) for the checkpoint position embedding
+             orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // orig_t_size) ** 0.5)
+             # height (== width) for the new position embedding
+             new_size = int((num_patches // new_t_size) ** 0.5)
+ 
+             # class_token and dist_token are kept unchanged
+             if orig_t_size != new_t_size:
+                 logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
+                 extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+                 # only the position tokens are interpolated
+                 pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+                 # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+                 pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+                 pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+                 pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+                 pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+                 pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+                 new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+                 checkpoint_model[pos_name] = new_pos_embed
+                 pos_embed_checkpoint = new_pos_embed
+ 
+             # class_token and dist_token are kept unchanged
+             if orig_size != new_size:
+                 logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
+                 extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+                 # only the position tokens are interpolated
+                 pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+                 # B, L, C -> BT, H, W, C -> BT, C, H, W
+                 pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+                 pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+                 pos_tokens = torch.nn.functional.interpolate(
+                     pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+                 # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+                 pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+                 pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
+                 new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+                 checkpoint_model[pos_name] = new_pos_embed
+ 
+     if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
+         raise NotImplementedError
+ 
+ 
+ def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size=8):
+     pos_names = []
+     for k in checkpoint_model.keys():
+         if ('pos_embed' in k or 'clip_pos_embed' in k) and 'img_pos_embed' not in k:
+             pos_names.append(k)
+ 
+     logger.info(f"pos names list for interpolating: {pos_names}")
+ 
+     assert len(pos_names) > 0, checkpoint_model.keys()
+ 
+     if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
+         raise NotImplementedError
+ 
+     # interpolate position embedding
+     for pos_name in pos_names:
+ 
+         pos_embed_checkpoint = checkpoint_model[pos_name]
+         embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
+ 
+         # we use 8 frames for pretraining
+         # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
+         new_t_size = model.num_frames // model.tubelet_size
+         # height (== width) for the checkpoint position embedding
+         orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // orig_t_size) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int((num_patches // new_t_size) ** 0.5)
+ 
+         # class_token and dist_token are kept unchanged
+         if orig_t_size != new_t_size:
+             logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
+             pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
+             pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
+             pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
+             pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model[pos_name] = new_pos_embed
+             pos_embed_checkpoint = new_pos_embed
+ 
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             # B, L, C -> BT, H, W, C -> BT, C, H, W
+             pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
+             pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model[pos_name] = new_pos_embed
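All three `interpolate_pos_embed*` helpers apply the same recipe: keep the extra (cls) tokens untouched, linearly resample the remaining tokens along the temporal axis, then bicubically resample them over the spatial grid, and write the result back into the checkpoint dict. The sketch below replays those reshape/interpolate steps on a dummy tensor so the bookkeeping is visible in isolation; the sizes and names are invented for illustration, and it does not call the repository functions:

```python
import torch
import torch.nn.functional as F

# Dummy position table: [1, T*H*W, C]; no cls token, for simplicity.
C, T_old, S_old = 64, 4, 16    # channels, old frame count, old grid side
T_new, S_new = 8, 14           # target frame count and grid side
pos = torch.randn(1, T_old * S_old * S_old, C)

# Temporal resample: [1, T, H*W, C] -> [H*W, C, T] -> linear interp -> [1, T_new*H*W, C]
x = pos.reshape(1, T_old, S_old * S_old, C).permute(0, 2, 3, 1).reshape(-1, C, T_old)
x = F.interpolate(x, size=T_new, mode="linear")
x = x.reshape(1, S_old * S_old, C, T_new).permute(0, 3, 1, 2).reshape(1, -1, C)

# Spatial resample: [T_new, H, W, C] -> [T_new, C, H, W] -> bicubic -> flatten back
x = x.reshape(T_new, S_old, S_old, C).permute(0, 3, 1, 2)
x = F.interpolate(x, size=(S_new, S_new), mode="bicubic", align_corners=False)
x = x.permute(0, 2, 3, 1).reshape(1, T_new * S_new * S_new, C)

assert x.shape == (1, T_new * S_new * S_new, C)
```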
test.ipynb ADDED
@@ -0,0 +1,424 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "metadata": {}
8
+ },
9
+ "outputs": [
10
+ {
11
+ "name": "stderr",
12
+ "output_type": "stream",
13
+ "text": [
14
+ "/root/miniconda3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
15
+ " from .autonotebook import tqdm as notebook_tqdm\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stdout",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "InternVideo2Config {\n",
23
+ " \"_attn_implementation_autoset\": true,\n",
24
+ " \"architectures\": [\n",
25
+ " \"InternVideo2_CLIP_small\"\n",
26
+ " ],\n",
27
+ " \"auto_map\": {\n",
28
+ " \"AutoConfig\": \"config.InternVideo2Config\",\n",
29
+ " \"AutoModel\": \"modeling_internvideo2encoder.InternVideo2_CLIP_small\"\n",
30
+ " },\n",
31
+ " \"auto_resume\": false,\n",
32
+ " \"batch_size\": 64,\n",
33
+ " \"batch_size_test\": 4,\n",
34
+ " \"best_key\": [\n",
35
+ " \"msrvtt_1k_test_match\",\n",
36
+ " \"t2v_r1\"\n",
37
+ " ],\n",
38
+ " \"compile_model\": false,\n",
39
+ " \"criterion\": {\n",
40
+ " \"clip_loss_ratio\": [\n",
41
+ " 1.0,\n",
42
+ " 1.0\n",
43
+ " ],\n",
44
+ " \"distill_final_features\": true,\n",
45
+ " \"loss_weight\": {\n",
46
+ " \"mlm\": 1.0,\n",
47
+ " \"mvm\": 0.0,\n",
48
+ " \"uta\": 0.0,\n",
49
+ " \"vtc\": 1.0,\n",
50
+ " \"vtm\": 1.0\n",
51
+ " },\n",
52
+ " \"mlm_masking_prob\": 0.5,\n",
53
+ " \"vtm_hard_neg\": true\n",
54
+ " },\n",
55
+ " \"debug\": false,\n",
56
+ " \"deep_fusion\": false,\n",
57
+ " \"deepspeed\": {\n",
58
+ " \"enable\": true,\n",
59
+ " \"stage\": 1\n",
60
+ " },\n",
61
+ " \"delete_ds_optim_states\": true,\n",
62
+ " \"device\": \"cuda\",\n",
63
+ " \"dist_url\": \"env://\",\n",
64
+ " \"evaluate\": false,\n",
65
+ " \"evaluation\": {\n",
66
+ " \"eval_frame_ensemble\": \"concat\",\n",
67
+ " \"eval_offload\": true,\n",
68
+ " \"eval_x_only\": false,\n",
69
+ " \"k_test\": 128\n",
70
+ " },\n",
71
+ " \"gradient_checkpointing\": true,\n",
72
+ " \"inputs\": {\n",
73
+ " \"batch_size\": {\n",
74
+ " \"image\": 64,\n",
75
+ " \"video\": 64\n",
76
+ " },\n",
77
+ " \"batch_size_test\": {\n",
78
+ " \"image\": 4,\n",
79
+ " \"video\": 4\n",
80
+ " },\n",
81
+ " \"image_res\": 224,\n",
82
+ " \"max_txt_l\": {\n",
83
+ " \"image\": 32,\n",
84
+ " \"video\": 32\n",
85
+ " },\n",
86
+ " \"video_input\": {\n",
87
+ " \"num_frames\": 8,\n",
88
+ " \"num_frames_test\": 8,\n",
89
+ " \"random_aug\": false,\n",
90
+ " \"sample_type\": \"middle\",\n",
91
+ " \"sample_type_test\": \"middle\"\n",
92
+ " }\n",
93
+ " },\n",
94
+ " \"jump_evaluate\": false,\n",
95
+ " \"log_freq\": 100,\n",
96
+ " \"max_txt_l\": 32,\n",
97
+ " \"mode\": \"pt\",\n",
98
+ " \"model\": {\n",
99
+ " \"embed_dim\": 1024,\n",
100
+ " \"find_unused_parameters\": false,\n",
101
+ " \"freeze_text\": true,\n",
102
+ " \"freeze_vision\": true,\n",
103
+ " \"load_vision_ckpt_from_internvideo2_stage2\": false,\n",
104
+ " \"model_cls\": \"InternVideo2_CLIP_small\",\n",
105
+ " \"multimodal\": {\n",
106
+ " \"enable\": true\n",
107
+ " },\n",
108
+ " \"open_text_projection\": false,\n",
109
+ " \"open_vision_clip_projector\": true,\n",
110
+ " \"temp\": 0.01,\n",
111
+ " \"temp_min\": 0.01,\n",
112
+ " \"text_encoder\": {\n",
113
+ " \"embed_dim\": 512,\n",
114
+ " \"image_cfg\": {\n",
115
+ " \"image_size\": 224,\n",
116
+ " \"model_name\": \"vit_b16\"\n",
117
+ " },\n",
118
+ " \"text_cfg\": {\n",
119
+ " \"causal_masking\": true,\n",
120
+ " \"context_length\": 77,\n",
121
+ " \"dim\": 512,\n",
122
+ " \"ffn_multiplier_per_layer\": 4.0,\n",
123
+ " \"model_name\": \"base\",\n",
124
+ " \"n_heads_per_layer\": 8,\n",
125
+ " \"n_transformer_layers\": 12,\n",
126
+ " \"norm_layer\": \"layer_norm_fp32\",\n",
127
+ " \"vocab_size\": 49408\n",
128
+ " }\n",
129
+ " },\n",
130
+ " \"vision_encoder\": {\n",
131
+ " \"align_dim\": 512,\n",
132
+ " \"attn_pool_num_heads\": 16,\n",
133
+ " \"checkpoint_num\": 0,\n",
134
+ " \"clip_embed_dim\": 768,\n",
135
+ " \"depth\": 24,\n",
136
+ " \"drop_cls_token\": false,\n",
137
+ " \"drop_path_rate\": 0.0,\n",
138
+ " \"embed_dim\": 1024,\n",
139
+ " \"fused_mlp_heuristic\": 1,\n",
140
+ " \"head_drop_path_rate\": 0.0,\n",
141
+ " \"img_size\": 224,\n",
142
+ " \"in_chans\": 3,\n",
143
+ " \"init_values\": 0.1,\n",
144
+ " \"layerscale_no_force_fp32\": true,\n",
145
+ " \"mlp_ratio\": 4,\n",
146
+ " \"name\": \"internvideo2_1B\",\n",
147
+ " \"num_frames\": 8,\n",
148
+ " \"num_heads\": 16,\n",
149
+ " \"patch_size\": 14,\n",
150
+ " \"qk_normalization\": true,\n",
151
+ " \"qkv_bias\": false,\n",
152
+ " \"sep_pos_embed\": false,\n",
153
+ " \"tubelet_size\": 1,\n",
154
+ " \"use_checkpoint\": false,\n",
155
+ " \"use_flash_attn\": false,\n",
156
+ " \"use_fused_mlp\": false,\n",
157
+ " \"use_fused_rmsnorm\": false\n",
158
+ " }\n",
159
+ " },\n",
160
+ " \"model_type\": \"internvideo2\",\n",
161
+ " \"num_frames\": 8,\n",
162
+ " \"num_frames_test\": 8,\n",
163
+ " \"num_workers\": 6,\n",
164
+ " \"optimizer\": {\n",
165
+ " \"different_lr\": {\n",
166
+ " \"enable\": false,\n",
167
+ " \"lr\": 0.001,\n",
168
+ " \"module_names\": []\n",
169
+ " },\n",
170
+ " \"lr\": 5e-05,\n",
171
+ " \"max_grad_norm\": 3.0,\n",
172
+ " \"opt\": \"adamW\",\n",
173
+ " \"opt_betas\": [\n",
174
+ " 0.9,\n",
175
+ " 0.98\n",
176
+ " ],\n",
177
+ " \"weight_decay\": 0.05\n",
178
+ " },\n",
179
+ " \"output_dir\": null,\n",
180
+ " \"pretrained_path\": \"\",\n",
181
+ " \"resume\": false,\n",
182
+ " \"save_ckpt_iter\": null,\n",
183
+ " \"save_latest\": true,\n",
184
+ " \"scheduler\": {\n",
185
+ " \"epochs\": 10,\n",
186
+ " \"min_lr_multi\": 0.01,\n",
187
+ " \"sched\": \"cosine\",\n",
188
+ " \"warmup_epochs\": 1\n",
189
+ " },\n",
190
+ " \"seed\": 42,\n",
191
+ " \"test_file\": {\n",
192
+ " \"didemo_ret_test\": \"available_corpus[\\\"didemo_ret_test\\\"]\",\n",
193
+ " \"msrvtt_1k_test\": \"available_corpus[\\\"msrvtt_1k_test\\\"]\"\n",
194
+ " },\n",
195
+ " \"test_types\": [\n",
196
+ " \"msrvtt_1k_test\",\n",
197
+ " \"didemo_ret_test\"\n",
198
+ " ],\n",
199
+ " \"text_enc\": \"bert_large\",\n",
200
+ " \"tokenizer\": null,\n",
201
+ " \"torch_dtype\": \"float32\",\n",
202
+ " \"train_file\": \"available_corpus[\\\"pretrain_example_data_1B\\\"]\",\n",
203
+ " \"transformers_version\": \"4.51.3\",\n",
204
+ " \"use_bf16\": true,\n",
205
+ " \"use_flash_sdp\": false,\n",
206
+ " \"use_half_precision\": false,\n",
207
+ " \"use_mem_efficient_sdp\": false,\n",
208
+ " \"wandb\": {\n",
209
+ " \"enable\": false,\n",
210
+ " \"entity\": \"opengvlab\",\n",
211
+ " \"project\": \"InternVideo2-Stage2\"\n",
212
+ " }\n",
213
+ "}\n",
214
+ "\n"
215
+ ]
216
+ }
217
+ ],
218
+ "source": [
219
+ "from transformers import AutoConfig, AutoModel\n",
220
+ "config = AutoConfig.from_pretrained(\"/fs-computility/video/heyinan/iv2hf/\", trust_remote_code=True)\n",
221
+ "model = AutoModel.from_pretrained(\"/fs-computility/video/heyinan/iv2hf/\", trust_remote_code=True).to(config.device)"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 2,
227
+ "metadata": {
228
+ "metadata": {}
229
+ },
230
+ "outputs": [],
231
+ "source": [
232
+ "import os\n",
233
+ "import random\n",
234
+ "import io\n",
235
+ "import av\n",
236
+ "import cv2\n",
237
+ "import decord\n",
238
+ "import imageio\n",
239
+ "from decord import VideoReader\n",
240
+ "import torch\n",
241
+ "import numpy as np\n",
242
+ "import math\n",
243
+ "import torch.nn.functional as F\n",
244
+ "decord.bridge.set_bridge(\"torch\")\n",
245
+ "\n",
246
+ "\n",
247
+ "def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1, start=None, end=None):\n",
248
+ " start_frame, end_frame = 0, vlen\n",
249
+ " if start is not None:\n",
250
+ " start_frame = max(start_frame,int(start * input_fps))\n",
251
+ " if end is not None:\n",
252
+ " end_frame = min(end_frame,int(end * input_fps))\n",
253
+ "\n",
254
+ " # Ensure start_frame is less than end_frame\n",
255
+ " if start_frame >= end_frame:\n",
256
+ " raise ValueError(\"Start frame index must be less than end frame index\")\n",
257
+ "\n",
258
+ " # Calculate the length of the clip in frames\n",
259
+ " clip_length = end_frame - start_frame\n",
260
+ "\n",
261
+ " if sample in [\"rand\", \"middle\"]: # uniform sampling\n",
262
+ " acc_samples = min(num_frames, clip_length)\n",
263
+ " # split the clip into `acc_samples` intervals, and sample from each interval.\n",
264
+ " intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)\n",
265
+ " ranges = []\n",
266
+ " for idx, interv in enumerate(intervals[:-1]):\n",
267
+ " ranges.append((interv, intervals[idx + 1] - 1))\n",
268
+ " if sample == 'rand':\n",
269
+ " try:\n",
270
+ " frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]\n",
271
+ " except:\n",
272
+ " frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame\n",
273
+ " frame_indices.sort()\n",
274
+ " frame_indices = list(frame_indices)\n",
275
+ " elif fix_start is not None:\n",
276
+ " frame_indices = [x[0] + fix_start for x in ranges]\n",
277
+ " elif sample == 'middle':\n",
278
+ " frame_indices = [(x[0] + x[1]) // 2 for x in ranges]\n",
279
+ " else:\n",
280
+ " raise NotImplementedError\n",
281
+ "\n",
282
+ " if len(frame_indices) < num_frames: # padded with last frame\n",
283
+ " padded_frame_indices = [frame_indices[-1]] * num_frames\n",
284
+ " padded_frame_indices[:len(frame_indices)] = frame_indices\n",
285
+ " frame_indices = padded_frame_indices\n",
286
+ " elif \"fps\" in sample: # fps0.5, sequentially sample frames at 0.5 fps\n",
287
+ " output_fps = float(sample[3:])\n",
288
+ " duration = float(clip_length) / input_fps\n",
289
+ " delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents\n",
290
+ " frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)\n",
291
+ " frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame\n",
292
+ " frame_indices = [e for e in frame_indices if e < end_frame]\n",
293
+ " if max_num_frames > 0 and len(frame_indices) > max_num_frames:\n",
294
+ " frame_indices = frame_indices[:max_num_frames]\n",
295
+ " # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)\n",
296
+ " else:\n",
297
+ " raise ValueError\n",
298
+ " return frame_indices\n",
299
+ "\n",
300
+ "def read_frames_decord(\n",
301
+ " video_path, num_frames, sample='middle', fix_start=None, \n",
302
+ " max_num_frames=-1, client=None, trimmed30=False, start=None, end=None\n",
303
+ " ):\n",
304
+ " num_threads = 1 if video_path.endswith('.webm') else 0 # make ssv2 happy\n",
305
+ "\n",
306
+ " video_reader = VideoReader(video_path, num_threads=num_threads)\n",
307
+ " vlen = len(video_reader)\n",
308
+ " \n",
309
+ " fps = video_reader.get_avg_fps()\n",
310
+ " duration = vlen / float(fps)\n",
311
+ "\n",
312
+ " frame_indices = get_frame_indices(\n",
313
+ " num_frames, vlen, sample=sample, fix_start=fix_start,\n",
314
+ " input_fps=fps, max_num_frames=max_num_frames, start=start, end=end\n",
315
+ " )\n",
316
+ "\n",
317
+ " frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8\n",
318
+ " frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8\n",
319
+ " return frames, frame_indices, duration"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": 3,
325
+ "metadata": {
326
+ "metadata": {}
327
+ },
328
+ "outputs": [],
329
+ "source": [
330
+ "def get_text_feature(model, texts):\n",
331
+ " text_input = model.tokenizer(texts).to(model.device)\n",
332
+ " text_features = model.encode_text(text_input)\n",
333
+ " return text_features\n",
334
+ " \n",
335
+ "def get_similarity(video_feature, text_feature):\n",
336
+ " video_feature = F.normalize(video_feature, dim=-1)\n",
337
+ " text_feature = F.normalize(text_feature, dim=-1)\n",
338
+ " sim_matrix = text_feature @ video_feature.T\n",
339
+ " return sim_matrix"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": 12,
345
+ "metadata": {
346
+ "metadata": {}
347
+ },
348
+ "outputs": [],
349
+ "source": [
350
+ "def get_top_videos(model, text_features, video_features, video_paths, texts):\n",
351
+ " # text_features = get_text_feature(texts)\n",
352
+ "\n",
353
+ " video_features = F.normalize(video_features, dim=-1)\n",
354
+ " text_features = F.normalize(text_features, dim=-1)\n",
355
+ "\n",
356
+ " # print(text_features.shape, video_features.shape)\n",
357
+ " sim_matrix = text_features @ video_features.T\n",
358
+ " # print(sim_matrix.shape)\n",
359
+ "\n",
360
+ " top_k = 5\n",
361
+ " sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]\n",
362
+ " softmax_sim_matrix = F.softmax(sim_matrix, dim=1)\n",
363
+ "\n",
364
+ " retrieval_infos = {}\n",
365
+ " for i in range(len(sim_matrix_top_k)):\n",
366
+ " print(\"\\n\",texts[i])\n",
367
+ " retrieval_infos[texts[i]] = []\n",
368
+ " for j in range(top_k):\n",
369
+ " print(\"top\", j+1, \":\", video_paths[sim_matrix_top_k[i][j]], \"~prob:\", sim_matrix[i][sim_matrix_top_k[i][j]].item())\n",
370
+ " retrieval_infos[texts[i]].append({\"video\": video_paths[sim_matrix_top_k[i][j]], \"prob\": sim_matrix[i][sim_matrix_top_k[i][j]].item(), \"rank\": j+1})\n",
371
+ " return retrieval_infos"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "metadata": {
378
+ "metadata": {}
379
+ },
380
+ "outputs": [],
381
+ "source": [
382
+ "if __name__==\"__main__\":\n",
383
+ " video_features = []\n",
384
+ " demo_videos = [\"video-scene-00030.mp4\",\"video-scene-00031.mp4\",\"xinhuashe_test_video/video-scene-00032.mp4\",\"xinhuashe_test_video/video-scene-00033.mp4\",\"video-scene-00034.mp4\"]\n",
385
+ " texts = ['a person talking', 'a logo', 'a building']\n",
386
+ " for video_path in demo_videos:\n",
387
+ " frames, frame_indices, video_duration = read_frames_decord(video_path,8)\n",
388
+ " frames = model.transform(frames).unsqueeze(0).to(model.device)\n",
389
+ " # 获得视频特征\n",
390
+ " with torch.no_grad():\n",
391
+ " video_feature = model.encode_vision(frames, test=True)\n",
392
+ " video_features.append(video_feature)\n",
393
+ " \n",
394
+ " # # 获得文本特征\n",
395
+ " text_features = get_text_feature(model, texts)\n",
396
+ " video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)\n",
397
+ " results = get_top_videos(model, text_features, video_features, demo_videos, texts)\n",
398
+ "\n",
399
+ "\n"
400
+ ]
401
+ }
402
+ ],
403
+ "metadata": {
404
+ "kernelspec": {
405
+ "display_name": "base",
406
+ "language": "python",
407
+ "name": "python3"
408
+ },
409
+ "language_info": {
410
+ "codemirror_mode": {
411
+ "name": "ipython",
412
+ "version": 3
413
+ },
414
+ "file_extension": ".py",
415
+ "mimetype": "text/x-python",
416
+ "name": "python",
417
+ "nbconvert_exporter": "python",
418
+ "pygments_lexer": "ipython3",
419
+ "version": "3.10.15"
420
+ }
421
+ },
422
+ "nbformat": 4,
423
+ "nbformat_minor": 2
424
+ }
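For reference, the retrieval step in test.ipynb reduces to a dot product between L2-normalized text and video embeddings; note that `get_top_videos` prints the raw cosine similarity under the "~prob:" label, so a softmax over the candidate videos is needed if probability-like scores are wanted. A toy, model-free sketch of that step (random features stand in for real InternVideo2 outputs):

```python
import torch
import torch.nn.functional as F

# Cosine similarity between text queries and candidate videos, plus a
# softmax over videos to turn similarities into a per-query distribution.
torch.manual_seed(0)
text_features = F.normalize(torch.randn(3, 512), dim=-1)    # 3 text queries
video_features = F.normalize(torch.randn(5, 512), dim=-1)   # 5 candidate videos

sim = text_features @ video_features.T          # (3, 5) cosine similarities
probs = sim.softmax(dim=1)                      # per-query distribution over videos
top_prob, top_idx = probs.topk(k=3, dim=1)      # top-3 videos per query

for q in range(sim.size(0)):
    print(f"query {q}: videos {top_idx[q].tolist()} with probs {top_prob[q].tolist()}")
```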