RuntimeError: Expected number of channels in input to be divisible by num_groups, but got input of shape [1, 5, 512, 7, 7] and num_groups=16

#2
by jlaitue - opened

Do you guys have any insight into this issue? I ran into it while running the inference demo from the repo:

  1. I ran the original inference code you posted.
  2. I loaded the provided example, just to run the model on a single input.
  3. The ViT, Llama and PyTorch weights are loaded into the cache correctly.
  4. I believe the issue occurs because an extra dimension is added to the tensor at some point: the last shape printed below is 5-D, [1, 5, 512, 7, 7], instead of 4-D [1, 512, 7, 7], so GroupNorm reads 5 as the channel count.

SHAPE!!! torch.Size([1, 64, 112, 112])
SHAPE!!! torch.Size([1, 64, 56, 56])
SHAPE!!! torch.Size([1, 64, 56, 56])
SHAPE!!! torch.Size([1, 64, 56, 56])
SHAPE!!! torch.Size([1, 64, 56, 56])
SHAPE!!! torch.Size([1, 128, 28, 28])
SHAPE!!! torch.Size([1, 128, 28, 28])
SHAPE!!! torch.Size([1, 128, 28, 28])
SHAPE!!! torch.Size([1, 128, 28, 28])
SHAPE!!! torch.Size([1, 128, 28, 28])
SHAPE!!! torch.Size([1, 256, 14, 14])
SHAPE!!! torch.Size([1, 256, 14, 14])
SHAPE!!! torch.Size([1, 256, 14, 14])
SHAPE!!! torch.Size([1, 256, 14, 14])
SHAPE!!! torch.Size([1, 256, 14, 14])
SHAPE!!! torch.Size([1, 512, 7, 7])
SHAPE!!! torch.Size([1, 512, 7, 7])
SHAPE!!! torch.Size([1, 512, 7, 7])
SHAPE!!! torch.Size([1, 512, 7, 7])
SHAPE!!! torch.Size([1, 512, 7, 7])
SHAPE!!! torch.Size([1, 5, 512, 7, 7])
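For reference, a minimal PyTorch sketch (not taken from the repo) reproduces the same error. `nn.GroupNorm` treats dim 1 as the channel axis, so the 4-D shape from the log normalizes fine, while the 5-D one fails because 5 is not divisible by 16. The layer sizes (16 groups, 512 channels) come from the error message; the random inputs are illustrative only.

```python
import torch
import torch.nn as nn

# GroupNorm(num_groups, num_channels) always reads dim 1 as the channel axis.
gn = nn.GroupNorm(num_groups=16, num_channels=512)

ok = torch.randn(1, 512, 7, 7)      # matches the last 4-D shape in the log
bad = torch.randn(1, 5, 512, 7, 7)  # matches the final 5-D shape in the log

print(gn(ok).shape)  # torch.Size([1, 512, 7, 7])

try:
    gn(bad)
except RuntimeError as err:
    # Fails: channels are read from dim 1, which is 5, not 512.
    print(err)
```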

Traceback (most recent call last):
  File "/scratch/ljl5178/Foundation/foundation-cxr/main_medversa.py", line 37, in <module>
    seg_mask_2d, seg_mask_3d, output_text = generate_predictions(model, images, context, prompt, modality, task, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature, device)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/ljl5178/Foundation/foundation-cxr/models/MedVersa_Internal/utils.py", line 321, in generate_predictions
    generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(model, images, image_tensor, context, modal, task, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
                                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/ljl5178/Foundation/foundation-cxr/models/MedVersa_Internal/utils.py", line 288, in generate
    seg_mask = task_seg_2d(model, preds, hidden_states, image)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/ljl5178/Foundation/foundation-cxr/models/MedVersa_Internal/utils.py", line 134, in task_seg_2d
    last_feats = model.text2seg_2d_gn(last_feats)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ljl5178/.conda/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ljl5178/.conda/envs/main/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/ljl5178/Foundation/foundation-cxr/models/MedVersa_Internal/medomni/models/medomni.py", line 38, in forward
    ret = super().forward(x.type(torch.float32))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ljl5178/.conda/envs/main/lib/python3.12/site-packages/torch/nn/modules/normalization.py", line 313, in forward
    return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ljl5178/.conda/envs/main/lib/python3.12/site-packages/torch/nn/functional.py", line 2965, in group_norm
    return torch.group_norm(
           ^^^^^^^^^^^^^^^^^
RuntimeError: Expected number of channels in input to be divisible by num_groups, but got input of shape [1, 5, 512, 7, 7] and num_groups=16
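One possible workaround, sketched under an assumption I have not confirmed: that dim 1 of the 5-D tensor is the number of input images (the `generate` call above receives a `num_imgs` argument). If so, the image axis can be folded into the batch axis before the norm and restored afterwards. `FoldedGroupNorm` is a hypothetical name, not a class in the repo, and the float32 round-trip only mirrors the cast visible at `medomni.py` line 38; casting back to the input dtype is my own assumption.

```python
import torch
import torch.nn as nn


class FoldedGroupNorm(nn.GroupNorm):
    """Hypothetical GroupNorm that also accepts [B, N, C, H, W] input.

    Assumption: dim 1 of a 5-D input is an image axis (N), not channels,
    so it is folded into the batch axis before normalizing.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 5:
            b, n, c, h, w = x.shape
            # Fold the image axis into the batch axis: [B*N, C, H, W].
            out = super().forward(x.reshape(b * n, c, h, w).float())
            # Restore the original 5-D layout and (assumed) input dtype.
            return out.reshape(b, n, c, h, w).type(x.dtype)
        # 4-D input: normalize in float32, then cast back (assumed behavior).
        return super().forward(x.float()).type(x.dtype)


gn = FoldedGroupNorm(16, 512)
y = gn(torch.randn(1, 5, 512, 7, 7))
print(y.shape)  # torch.Size([1, 5, 512, 7, 7])
```

Because GroupNorm normalizes each sample independently, folding the image axis into the batch gives the same result as normalizing each image separately with the same weights. Whether the later layers in `task_seg_2d` expect one mask per image is a separate question for the maintainers.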
