Fix image embedding logic to be mps-compatible
#45 by DefOs9 · opened

modeling_phi4mm.py (+70 -40) CHANGED
@@ -325,7 +325,7 @@ class Phi4MMImageEmbedding(nn.Module):
                bs = img_embeds.shape[0]
                # Nx(HW)xC
                if image_attention_mask is not None and len(image_attention_mask) > 0:
-                    img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.…
+                    img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.bool().flatten(0,1).to(target_device))
                else:
                    img_features = self.get_img_features(img_embeds.flatten(0, 1))

@@ -337,13 +337,12 @@ class Phi4MMImageEmbedding(nn.Module):

                assert base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform'

-                # bs x max_num_crops x (…
+                # bs x max_num_crops x (HxH) x C
                img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
                C = self.image_dim_out
                H = base_feat_height

                output_imgs = []
-                output_len = []
                # training is tensor, inference is list
                if isinstance(img_sizes, torch.Tensor):
                    img_sizes = img_sizes.view(-1, 2)
@@ -353,39 +352,71 @@ class Phi4MMImageEmbedding(nn.Module):
                    w = w // base_resolution
                    B_ = h * w

-                    # 1 x (…
+                    # 1 x (HxH) x C
                    global_img_feature = img_features[_bs, :1]

-                    # 1 x…
-                    glb_img =…
-…
+                    # 1 x H x H x C
+                    glb_img = (
+                        global_img_feature
+                        .reshape(1, H, H, C)
+                        .reshape(1, H // base_feat_height_reduction, base_feat_height_reduction,
+                                 H // base_feat_height_reduction, base_feat_height_reduction, C)
+                        .contiguous()
+                        .permute(0, 1, 3, 2, 4, 5)
+                        .reshape(1, H // base_feat_height_reduction, H // base_feat_height_reduction,
+                                 base_feat_height_reduction * base_feat_height_reduction * C)
+                        .contiguous()
+                    )
+                    temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1)

-                    # 1 x…
-                    glb_img =…
+                    # 1 x (HxH+H) x C
+                    glb_img = (
+                        torch.cat([glb_img, temp_glb_GN], dim=2)
+                        .reshape(1, -1, base_feat_height_reduction * base_feat_height_reduction * C)
+                    )

-                    # (max_num_crops-1) x (…
+                    # (max_num_crops-1) x (HxH) x C
                    sub_img = img_features[_bs, 1:]
-                    # 16x574x1024
                    # get rid of padding sub_img
                    sub_img = sub_img[:B_]

-…
-…
-…
+                    sub_img = (
+                        sub_img
+                        .reshape(B_, H, H, C)
+                        .reshape(B_, H // base_feat_height_reduction, base_feat_height_reduction,
+                                 H // base_feat_height_reduction, base_feat_height_reduction, C)
+                        .contiguous()
+                        .permute(0, 1, 3, 2, 4, 5)
+                        .reshape(B_, -1, base_feat_height_reduction * base_feat_height_reduction * C)
+                        .contiguous()
+                    )
+                    sub_img = (
+                        sub_img
+                        .reshape(1, h, w, base_feat_height // base_feat_height_reduction,
+                                 base_feat_width // base_feat_height_reduction, -1)
+                        .permute(0, 1, 3, 2, 4, 5)
+                        .reshape(1, h * base_feat_height // base_feat_height_reduction,
+                                 w * base_feat_width // base_feat_height_reduction,
+                                 base_feat_height_reduction * base_feat_height_reduction * C)
+                    )

                    if image_attention_mask is not None and len(image_attention_mask) > 0:
-                        reshaped_image_attention_mask =…
-…
-…
-…
+                        reshaped_image_attention_mask = (
+                            image_attention_mask[_bs, 1:B_ + 1, 0::2, 0::2]
+                            .reshape(1, h, w, base_feat_height // base_feat_height_reduction,
+                                     base_feat_width // base_feat_height_reduction)
+                            .permute(0, 1, 3, 2, 4)
+                            .reshape(1, h * base_feat_height // base_feat_height_reduction,
+                                     w * base_feat_width // base_feat_height_reduction)
+                        )
+                        useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
+                        useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
+                        sub_img = sub_img[:, :useful_height, :useful_width]
                        temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
-                        temp_len = int(image_attention_mask[_bs,:B_+1,0::2,0::2].sum().item()) + (useful_height+1) + base_feat_height//base_feat_height_reduction
                    else:
-                        temp_sub_GN = self.sub_GN.repeat(1, h*base_feat_height//base_feat_height_reduction, 1, 1)
-                        temp_len = int((h*w+1)*self.num_img_tokens+ 1 + (h+1)*base_feat_height//base_feat_height_reduction)
+                        temp_sub_GN = self.sub_GN.repeat(1, h * base_feat_height // base_feat_height_reduction, 1, 1)

-                    sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1…
-                    # (1, num_img_tokens, 1024*4)
+                    sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, base_feat_height_reduction * base_feat_height_reduction * C)

                    # glb + sub
                    if self.hd_transform_order == 'glb_sub':
@@ -395,17 +426,11 @@ class Phi4MMImageEmbedding(nn.Module):
                    else:
                        raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented')

-                    #temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
-                    assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
-                    output_len.append(temp_len)
-…
-                num_img_tokens = output_len
                img_set_tensor = []
                for _output_img in output_imgs:
                    img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype))
                    img_set_tensor.append(img_feature_proj)
-                #logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')
-                #assert sum(num_img_tokens) == len(g_values), f'(branch 1) sum(num_img_tokens): {sum(num_img_tokens)}, g_values size: {len(g_values)}, g_values {g_values}'
+                # logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')

            else:
                raise NotImplementedError
@@ -420,7 +445,7 @@ class Phi4MMImageEmbedding(nn.Module):
                    self.get_img_features(img_embeds)
                    .to(target_device)
                    .to(target_dtype)
-                    .reshape(-1,…
+                    .reshape(-1, self.image_dim_out)
                )
                if self.use_hd_transform:
                    img_set_tensor = self.img_projection(tt.reshape(-1, self.image_dim_out*self.base_feat_height_reduction**2) * self.glb_GN[0] * self.sub_GN[0, 0])
@@ -442,14 +467,19 @@ class Phi4MMImageEmbedding(nn.Module):
            # Shape: (merged_N_tokens, C)
            merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
            merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(hidden_states.device)
-…
-…
-…
-            new_hidden_states =…
-…
-…
-…
-            )
+            if hidden_states.device.type == "mps":
+                # For MPS, assign using direct indexing to avoid index_put issues.
+                new_hidden_states = hidden_states.clone()
+                new_hidden_states[positions_tuple] = merged_img_set_tensor
+            else:
+                # Temporarily disable autocast to avoid issue on bf16 tensors
+                # Ref: https://github.com/pytorch/pytorch/issues/132715
+                with torch.autocast(device_type=hidden_states.device.type, enabled=False):
+                    new_hidden_states = hidden_states.index_put(
+                        indices=positions_tuple,
+                        values=merged_img_set_tensor,
+                        accumulate=False
+                    )
            hidden_states = new_hidden_states
        else:
            raise NotImplementedError
@@ -2096,7 +2126,7 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if isinstance(input_mode, torch.Tensor):
-…
+            assert len(input_mode) == 1
            input_mode = input_mode[0].item()
        input_mode = InputMode(input_mode)
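
For reference, the sketch below isolates the device-aware write-back pattern used in the last image-embedding hunk: a plain advanced-indexing assignment on a clone when running on MPS, and out-of-place index_put with autocast disabled everywhere else (per the linked PyTorch issue). The helper name insert_image_embeddings and the toy shapes are illustrative only and are not part of modeling_phi4mm.py; positions_tuple follows the naming in the diff. Separately, the first hunk casts the image attention mask with .bool(), a device-preserving cast, before flattening it and moving it to target_device.

# Minimal, self-contained sketch of the device-aware scatter shown in the diff above.
# Function name and toy shapes are illustrative; positions_tuple follows the diff's naming.
import torch

def insert_image_embeddings(hidden_states: torch.Tensor,
                            positions_tuple: tuple,
                            image_embeds: torch.Tensor) -> torch.Tensor:
    # Match dtype/device of the destination, as the diff does for merged_img_set_tensor.
    image_embeds = image_embeds.to(hidden_states.dtype).to(hidden_states.device)
    if hidden_states.device.type == "mps":
        # MPS: advanced-indexing assignment on a clone sidesteps index_put issues.
        out = hidden_states.clone()
        out[positions_tuple] = image_embeds
        return out
    # CPU/CUDA: out-of-place index_put with autocast disabled
    # (ref: https://github.com/pytorch/pytorch/issues/132715).
    with torch.autocast(device_type=hidden_states.device.type, enabled=False):
        return hidden_states.index_put(indices=positions_tuple,
                                       values=image_embeds,
                                       accumulate=False)

# Toy usage: write 2 image-token embeddings into a (1, 6, 8) hidden-state tensor.
hs = torch.zeros(1, 6, 8)
positions = (torch.tensor([0, 0]), torch.tensor([2, 3]))  # (batch indices, sequence indices)
img = torch.ones(2, 8)
hs = insert_image_embeddings(hs, positions, img)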