DefOs9 committed
Commit f580f2d · verified · 1 Parent(s): 18812f4

Fix image embedding logic to be mps-compatible


Addresses the assertion error raised on mps machines. Cf. https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/12

- MPS changes:
  - `.bool()` instead of `.type(torch.BoolTensor)`, which keeps the attention mask on its original device (see the sketch after this list).
  - Avoid `index_put` issues by adding an MPS-specific logical block (a standalone sketch of this follows the diff below).
  - The `temp_len` variable checked by the assertion was never used anywhere else, so I removed the variable and the offending assertion.
- Various cleanup of comments and code.
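
A minimal, self-contained sketch (not part of the commit) of why the mask cast matters: `torch.BoolTensor` names the *CPU* bool tensor type, so `Tensor.type(torch.BoolTensor)` always produces a CPU tensor, while `Tensor.bool()` only changes the dtype and leaves the tensor on whatever device it already lives on. The mask shape below is made up purely for illustration.

```python
import torch

# Illustrative mask only; the shape is made up for this sketch.
mask = torch.ones(2, 3, 32, 32)

# Legacy-style cast: torch.BoolTensor is specifically the CPU bool tensor type,
# so the result lands on the CPU no matter where `mask` lives.
legacy = mask.type(torch.BoolTensor)
print(legacy.dtype, legacy.device)  # torch.bool cpu

# Cast used by this commit: only the dtype changes, the device is preserved.
device = "mps" if torch.backends.mps.is_available() else "cpu"
fixed = mask.to(device).bool()
print(fixed.dtype, fixed.device)  # torch.bool mps:0 (or cpu when MPS is unavailable)
```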

Files changed (1)
  1. modeling_phi4mm.py +70 -40
modeling_phi4mm.py CHANGED
@@ -325,7 +325,7 @@ class Phi4MMImageEmbedding(nn.Module):
             bs = img_embeds.shape[0]
             # Nx(HW)xC
             if image_attention_mask is not None and len(image_attention_mask) > 0:
-                img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.type(torch.BoolTensor).flatten(0,1).to(target_device))
+                img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.bool().flatten(0,1).to(target_device))
             else:
                 img_features = self.get_img_features(img_embeds.flatten(0, 1))
 
@@ -337,13 +337,12 @@ class Phi4MMImageEmbedding(nn.Module):
 
             assert base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform'
 
-            # bs x max_num_crops x (24x24) x C
+            # bs x max_num_crops x (HxH) x C
             img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
             C = self.image_dim_out
             H = base_feat_height
 
             output_imgs = []
-            output_len = []
             # training is tensor, inference is list
             if isinstance(img_sizes, torch.Tensor):
                 img_sizes = img_sizes.view(-1, 2)
@@ -353,39 +352,71 @@ class Phi4MMImageEmbedding(nn.Module):
                 w = w // base_resolution
                 B_ = h * w
 
-                # 1 x (24x24) x 1024
+                # 1 x (HxH) x C
                 global_img_feature = img_features[_bs, :1]
 
-                # 1 x 12 x 12 x 4096
-                glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C).contiguous()
-                temp_glb_GN = self.sub_GN.repeat(1, H//base_feat_height_reduction, 1, 1)
+                # 1 x H x H x C
+                glb_img = (
+                    global_img_feature
+                    .reshape(1, H, H, C)
+                    .reshape(1, H // base_feat_height_reduction, base_feat_height_reduction,
+                             H // base_feat_height_reduction, base_feat_height_reduction, C)
+                    .contiguous()
+                    .permute(0, 1, 3, 2, 4, 5)
+                    .reshape(1, H // base_feat_height_reduction, H // base_feat_height_reduction,
+                             base_feat_height_reduction * base_feat_height_reduction * C)
+                    .contiguous()
+                )
+                temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1)
 
-                # 1 x 156 x 4096
-                glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C)
+                # 1 x (HxH+H) x C
+                glb_img = (
+                    torch.cat([glb_img, temp_glb_GN], dim=2)
+                    .reshape(1, -1, base_feat_height_reduction * base_feat_height_reduction * C)
+                )
 
-                # (max_num_crops-1) x (12x12) x C
+                # (max_num_crops-1) x (HxH) x C
                 sub_img = img_features[_bs, 1:]
-                # 16x574x1024
                 # get rid of padding sub_img
                 sub_img = sub_img[:B_]
 
-                # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
-                sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,base_feat_height_reduction*base_feat_height_reduction*C).contiguous()
-                sub_img = sub_img.reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction, -1).permute(0,1,3,2,4,5).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C)
+                sub_img = (
+                    sub_img
+                    .reshape(B_, H, H, C)
+                    .reshape(B_, H // base_feat_height_reduction, base_feat_height_reduction,
+                             H // base_feat_height_reduction, base_feat_height_reduction, C)
+                    .contiguous()
+                    .permute(0, 1, 3, 2, 4, 5)
+                    .reshape(B_, -1, base_feat_height_reduction * base_feat_height_reduction * C)
+                    .contiguous()
+                )
+                sub_img = (
+                    sub_img
+                    .reshape(1, h, w, base_feat_height // base_feat_height_reduction,
+                             base_feat_width // base_feat_height_reduction, -1)
+                    .permute(0, 1, 3, 2, 4, 5)
+                    .reshape(1, h * base_feat_height // base_feat_height_reduction,
+                             w * base_feat_width // base_feat_height_reduction,
+                             base_feat_height_reduction * base_feat_height_reduction * C)
+                )
 
                 if image_attention_mask is not None and len(image_attention_mask) > 0:
-                    reshaped_image_attention_mask = image_attention_mask[_bs,1:B_+1,0::2,0::2].reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction).permute(0,1,3,2,4).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction)
-                    useful_height = int(reshaped_image_attention_mask[0,:,0].sum().item())
-                    useful_width = int(reshaped_image_attention_mask[0,0,:].sum().item())
-                    sub_img = sub_img[:,:useful_height, :useful_width]
+                    reshaped_image_attention_mask = (
+                        image_attention_mask[_bs, 1:B_ + 1, 0::2, 0::2]
+                        .reshape(1, h, w, base_feat_height // base_feat_height_reduction,
+                                 base_feat_width // base_feat_height_reduction)
+                        .permute(0, 1, 3, 2, 4)
+                        .reshape(1, h * base_feat_height // base_feat_height_reduction,
+                                 w * base_feat_width // base_feat_height_reduction)
+                    )
+                    useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
+                    useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
+                    sub_img = sub_img[:, :useful_height, :useful_width]
                     temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
-                    temp_len = int(image_attention_mask[_bs,:B_+1,0::2,0::2].sum().item()) + (useful_height+1) + base_feat_height//base_feat_height_reduction
                 else:
-                    temp_sub_GN = self.sub_GN.repeat(1, h*base_feat_height//base_feat_height_reduction, 1, 1)
-                    temp_len = int((h*w+1)*self.num_img_tokens+ 1 + (h+1)*base_feat_height//base_feat_height_reduction)
+                    temp_sub_GN = self.sub_GN.repeat(1, h * base_feat_height // base_feat_height_reduction, 1, 1)
 
-                sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C)
-                # (1, num_img_tokens, 1024*4)
+                sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, base_feat_height_reduction * base_feat_height_reduction * C)
 
                 # glb + sub
                 if self.hd_transform_order == 'glb_sub':
@@ -395,17 +426,11 @@ class Phi4MMImageEmbedding(nn.Module):
                 else:
                     raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented')
 
-                #temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
-                assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
-                output_len.append(temp_len)
-
-            num_img_tokens = output_len
             img_set_tensor = []
             for _output_img in output_imgs:
                 img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype))
                 img_set_tensor.append(img_feature_proj)
-            #logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')
-            #assert sum(num_img_tokens) == len(g_values), f'(branch 1) sum(num_img_tokens): {sum(num_img_tokens)}, g_values size: {len(g_values)}, g_values {g_values}'
+            # logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')
 
         else:
             raise NotImplementedError
@@ -420,7 +445,7 @@ class Phi4MMImageEmbedding(nn.Module):
                 self.get_img_features(img_embeds)
                 .to(target_device)
                 .to(target_dtype)
-                .reshape(-1, 1024)
+                .reshape(-1, self.image_dim_out)
             )
             if self.use_hd_transform:
                 img_set_tensor = self.img_projection(tt.reshape(-1, self.image_dim_out*self.base_feat_height_reduction**2) * self.glb_GN[0] * self.sub_GN[0, 0])
@@ -442,14 +467,19 @@ class Phi4MMImageEmbedding(nn.Module):
             # Shape: (merged_N_tokens, C)
             merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
             merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(hidden_states.device)
-            # Temporarily disable autocast to avoid issue on bf16 tensors
-            # Ref: https://github.com/pytorch/pytorch/issues/132715
-            with torch.autocast(device_type=hidden_states.device.type, enabled=False):
-                new_hidden_states = hidden_states.index_put(
-                    indices=positions_tuple,
-                    values=merged_img_set_tensor,
-                    accumulate=False
-                )
+            if hidden_states.device.type == "mps":
+                # For MPS, assign using direct indexing to avoid index_put issues.
+                new_hidden_states = hidden_states.clone()
+                new_hidden_states[positions_tuple] = merged_img_set_tensor
+            else:
+                # Temporarily disable autocast to avoid issue on bf16 tensors
+                # Ref: https://github.com/pytorch/pytorch/issues/132715
+                with torch.autocast(device_type=hidden_states.device.type, enabled=False):
+                    new_hidden_states = hidden_states.index_put(
+                        indices=positions_tuple,
+                        values=merged_img_set_tensor,
+                        accumulate=False
+                    )
             hidden_states = new_hidden_states
         else:
             raise NotImplementedError
@@ -2096,7 +2126,7 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if isinstance(input_mode, torch.Tensor):
-            # len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
+            assert len(input_mode) == 1
            input_mode = input_mode[0].item()
            input_mode = InputMode(input_mode)
 
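
For reference, the MPS-specific replacement for `index_put` introduced above can be exercised in isolation. This is a toy sketch rather than code from the commit: the tensor shapes, values, and `positions_tuple` indices are invented for illustration, while the branch structure mirrors the new logic in `modeling_phi4mm.py`.

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Toy stand-ins: a (seq_len, C) hidden state and two positions to overwrite.
hidden_states = torch.zeros(6, 4, device=device)
merged_img_set_tensor = torch.arange(8, dtype=torch.float32, device=device).reshape(2, 4)
positions_tuple = (torch.tensor([1, 3], device=device),)

if device == "mps":
    # MPS branch from the commit: plain advanced-indexing assignment on a clone.
    new_hidden_states = hidden_states.clone()
    new_hidden_states[positions_tuple] = merged_img_set_tensor
else:
    # Original branch: out-of-place index_put with autocast disabled.
    with torch.autocast(device_type=device, enabled=False):
        new_hidden_states = hidden_states.index_put(
            indices=positions_tuple,
            values=merged_img_set_tensor,
            accumulate=False,
        )

print(new_hidden_states)
```

Both branches write the image features into the same rows; the MPS path simply avoids the `index_put` call that was causing trouble on Apple-silicon machines per the linked discussion.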