File size: 14,561 Bytes
ff495b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
import numpy as np
import torch
import logging
logger = logging.getLogger(__name__)
# --------------------------------------------------------
# 3D sine-cosine position embedding
# References:
# MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
# --------------------------------------------------------
def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
    """Return a fixed 3D (T, H, W) sine-cosine position embedding.

    The channel dimension is split 1:3 between time and space: the first
    embed_dim // 4 channels encode the frame index, the remaining
    embed_dim // 4 * 3 channels hold the 2D spatial embedding.

    grid_size: int of the grid height and width
    t_size: int of the temporal size
    cls_token: if True, prepend an all-zero row for a class token
    return:
    pos_embed: [t_size*grid_size*grid_size, embed_dim] (w/o cls_token) or
        [1+t_size*grid_size*grid_size, embed_dim] (w/ cls_token),
        rows ordered as [T, H, W].
    """
    # The 1D/2D helpers require every half of each sub-embedding to be even;
    # for a 1:3 temporal/spatial split that is exactly embed_dim % 16 == 0.
    assert embed_dim % 16 == 0, f"embed_dim must be divisible by 16, got {embed_dim}"
    embed_dim_spatial = embed_dim // 4 * 3
    embed_dim_temporal = embed_dim // 4

    pos_embed_spatial = get_2d_sincos_pos_embed(embed_dim_spatial, grid_size)  # [H*W, D // 4 * 3]
    pos_embed_temporal = get_1d_sincos_pos_embed(embed_dim_temporal, t_size)  # [T, D // 4]

    # Broadcast both parts to [T, H*W, ...] and concatenate along channels.
    pos_embed_temporal = np.repeat(
        pos_embed_temporal[:, np.newaxis, :], grid_size**2, axis=1
    )  # [T, H*W, D // 4]
    pos_embed_spatial = np.repeat(
        pos_embed_spatial[np.newaxis, :, :], t_size, axis=0
    )  # [T, H*W, D // 4 * 3]
    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    if cls_token:
        # Keep the embedding's dtype; a bare np.zeros is float64 and would
        # upcast the whole array only when cls_token=True.
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim], dtype=pos_embed.dtype), pos_embed], axis=0
        )
    return pos_embed
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Return a fixed 2D sine-cosine position embedding for a square grid.

    grid_size: int of the grid height and width
    cls_token: if True, prepend an all-zero row for a class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) or
        [1+grid_size*grid_size, embed_dim] (w/ cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Keep the embedding's dtype; a bare np.zeros is float64 and would
        # upcast the whole array only when cls_token=True.
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim], dtype=pos_embed.dtype), pos_embed], axis=0
        )
    return pos_embed
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Return a fixed 1D sine-cosine position embedding for positions 0..t_size-1.

    t_size: int of the temporal size
    cls_token: if True, prepend an all-zero row for a class token
    return:
    pos_embed: [t_size, embed_dim] (w/o cls_token) or
        [1+t_size, embed_dim] (w/ cls_token)
    """
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
    if cls_token:
        # Keep the embedding's dtype; a bare np.zeros is float64 and would
        # upcast the whole array only when cls_token=True.
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim], dtype=pos_embed.dtype), pos_embed], axis=0
        )
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build a 2D sine-cosine embedding from two coordinate grids.

    embed_dim: total output dimension; must be even. Each coordinate array
        receives embed_dim // 2 channels.
    grid: array indexable as grid[0] and grid[1], one coordinate array each.
        The callers visible in this file stack np.meshgrid(grid_w, grid_h),
        so grid[0] holds the column (w) coordinate and grid[1] the row (h)
        coordinate.
    return: (M, embed_dim) array with M = grid[0].size; the first half of
        the channels encodes grid[0], the second half grid[1].
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])  # (M, D/2)
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine/cosine features.

    embed_dim: number of output features per position; must be even.
    pos: numpy array of positions (any shape; flattened to (M,)).
    return: (M, embed_dim) array. The first embed_dim // 2 columns are
        sin(pos * omega_i) and the last embed_dim // 2 columns are
        cos(pos * omega_i), with omega_i = 1 / 10000 ** (i / (embed_dim / 2)).
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents  # (D/2,)
    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, np.newaxis] * frequencies[np.newaxis, :]  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def _interpolate_pos_tokens_temporal(pos_embed, num_extra_tokens, orig_t_size, new_t_size):
    """Linearly resample the frame axis of a position embedding.

    pos_embed: tensor whose tokens after the first ``num_extra_tokens`` are
        reshaped to [1, orig_t_size, HW, C], i.e. laid out frame-major.
    Returns a new [1, num_extra_tokens + new_t_size * HW, C] tensor with the
    same dtype as ``pos_embed``; the extra (class/dist) tokens are copied as-is.
    """
    embedding_size = pos_embed.shape[-1]  # channel dim
    extra_tokens = pos_embed[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed[:, num_extra_tokens:]
    orig_dtype = pos_tokens.dtype
    # Interpolate in at least float32: fp16/bf16 checkpoints lose precision and,
    # depending on the PyTorch version, CPU kernels may lack half support.
    # float32/float64 inputs are left untouched.
    pos_tokens = pos_tokens.to(torch.promote_types(orig_dtype, torch.float32))
    # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
    pos_tokens = pos_tokens.reshape(1, orig_t_size, -1, embedding_size)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
    pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
    # BHW, C, T' -> B, HW, C, T' -> B, T', HW, C -> B, T'*HW, C
    pos_tokens = pos_tokens.reshape(1, -1, embedding_size, new_t_size)
    pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
    return torch.cat((extra_tokens, pos_tokens.to(orig_dtype)), dim=1)


def _interpolate_pos_tokens_spatial(pos_embed, num_extra_tokens, t_size, orig_size, new_size):
    """Bicubically resize every frame's square spatial grid of a position embedding.

    pos_embed: tensor whose tokens after the first ``num_extra_tokens`` are
        reshaped to [B, t_size, orig_size, orig_size, C].
    Returns a new [1, num_extra_tokens + t_size * new_size**2, C] tensor (B = 1)
    with the same dtype as ``pos_embed``; extra tokens are copied as-is.
    """
    embedding_size = pos_embed.shape[-1]  # channel dim
    extra_tokens = pos_embed[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed[:, num_extra_tokens:]
    orig_dtype = pos_tokens.dtype
    # Same upcast as the temporal pass; a no-op for float32/float64.
    pos_tokens = pos_tokens.to(torch.promote_types(orig_dtype, torch.float32))
    # B, L, C -> B, T, H, W, C -> BT, H, W, C -> BT, C, H, W
    pos_tokens = pos_tokens.reshape(-1, t_size, orig_size, orig_size, embedding_size)
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    # BT, C, H', W' -> BT, H', W', C -> B, T, H', W', C -> B, T*H'*W', C
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, t_size, new_size, new_size, embedding_size)
    pos_tokens = pos_tokens.flatten(1, 3)
    return torch.cat((extra_tokens, pos_tokens.to(orig_dtype)), dim=1)


def _interpolate_pos_embed_entry(checkpoint_model, model, pos_name, orig_t_size, new_t_size):
    """Resize ``checkpoint_model[pos_name]`` in place to ``model``'s token grid.

    The extra-token count is ``model.pos_embed.shape[-2] - model.patch_embed.num_patches``
    (expected to be 0 or 1); class/dist tokens are kept unchanged. The checkpoint is
    read as ``orig_t_size`` frames of a square grid and resized to ``new_t_size``
    frames of a square grid holding ``num_patches // new_t_size`` tokens: time first
    (linear), then space (bicubic).
    """
    pos_embed_checkpoint = checkpoint_model[pos_name]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
    # height (== width) for the checkpoint position embedding
    orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // orig_t_size) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int((num_patches // new_t_size) ** 0.5)
    if orig_t_size != new_t_size:
        logger.info(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
        pos_embed_checkpoint = _interpolate_pos_tokens_temporal(
            pos_embed_checkpoint, num_extra_tokens, orig_t_size, new_t_size)
        checkpoint_model[pos_name] = pos_embed_checkpoint
    if orig_size != new_size:
        logger.info(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
        checkpoint_model[pos_name] = _interpolate_pos_tokens_spatial(
            pos_embed_checkpoint, num_extra_tokens, new_t_size, orig_size, new_size)


def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
    """Resize ``checkpoint_model[pos_name]`` in place so it fits ``model``.

    checkpoint_model: state dict to patch; left untouched if ``pos_name`` is absent.
    model: provides ``T`` (target number of temporal positions),
        ``patch_embed.num_patches`` and ``pos_embed``.
    orig_t_size: number of temporal positions in the checkpoint embedding
        (4 frames were used for pretraining).
    pos_name: state-dict key of the position embedding.
    """
    if pos_name not in checkpoint_model:
        return
    new_t_size = model.T
    _interpolate_pos_embed_entry(checkpoint_model, model, pos_name, orig_t_size, new_t_size)


_FACTORISED_POS_EMBED_MSG = (
    "interpolating factorised 'pos_embed_spatial'/'pos_embed_temporal' is not supported"
)


def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size=8):
    """Resize 'pos_embed' and 'clip_pos_embed' of ``checkpoint_model`` in place.

    The target temporal size is ``model.num_frames // model.tubelet_size``;
    ``orig_t_size`` is the checkpoint's (8 frames were used for pretraining).

    Raises:
        NotImplementedError: if the checkpoint stores 'pos_embed_spatial' or
            'pos_embed_temporal'. This is checked before anything is modified,
            so the state dict is never left half-converted.
    """
    if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
        raise NotImplementedError(_FACTORISED_POS_EMBED_MSG)
    for pos_name in ['pos_embed', 'clip_pos_embed']:
        if pos_name in checkpoint_model:
            new_t_size = model.num_frames // model.tubelet_size
            _interpolate_pos_embed_entry(checkpoint_model, model, pos_name, orig_t_size, new_t_size)


def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size=8):
    """Resize every position-embedding entry of ``checkpoint_model`` in place.

    Every key containing 'pos_embed' (which also covers 'clip_pos_embed') is
    resized, except keys containing 'img_pos_embed'. All matched entries are
    resized against ``model.pos_embed`` / ``model.patch_embed.num_patches``.
    NOTE(review): this assumes every matched entry shares that token layout;
    confirm when new keys appear.

    Raises:
        AssertionError: if no matching key is found.
        NotImplementedError: if the checkpoint stores 'pos_embed_spatial' or
            'pos_embed_temporal' (checked before anything is modified).
    """
    # 'clip_pos_embed' contains 'pos_embed', so one substring test covers both.
    pos_names = [k for k in checkpoint_model.keys() if 'pos_embed' in k and 'img_pos_embed' not in k]
    logger.info(f"pos names list for interpolating: {pos_names}")
    assert len(pos_names) > 0, checkpoint_model.keys()
    if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
        raise NotImplementedError(_FACTORISED_POS_EMBED_MSG)
    for pos_name in pos_names:
        new_t_size = model.num_frames // model.tubelet_size
        _interpolate_pos_embed_entry(checkpoint_model, model, pos_name, orig_t_size, new_t_size)