Example code throws an error only with *-896

#4
by ryj0902 - opened

The code provided in the model card does not cause an error when using the google/paligemma2-10b-pt-448 model, but does cause an error when using the google/paligemma2-10b-pt-896 model.

Output summary:

../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [16386,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [16386,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [16386,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
...(more than 100 lines)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 11
      8 input_len = model_inputs["input_ids"].shape[-1]
     10 with torch.inference_mode():
---> 11     generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
     12     generation = generation[0][input_len:]
     13     decoded = processor.decode(generation, skip_special_tokens=True)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2255, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2247     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2248         input_ids=input_ids,
   2249         expand_size=generation_config.num_return_sequences,
   2250         is_encoder_decoder=self.config.is_encoder_decoder,
   2251         **model_kwargs,
   2252     )
   2254     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2255     result = self._sample(
   2256         input_ids,
   2257         logits_processor=prepared_logits_processor,
   2258         stopping_criteria=prepared_stopping_criteria,
   2259         generation_config=generation_config,
   2260         synced_gpus=synced_gpus,
   2261         streamer=streamer,
   2262         **model_kwargs,
   2263     )
   2265 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2266     # 11. prepare beam search scorer
   2267     beam_scorer = BeamSearchScorer(
   2268         batch_size=batch_size,
   2269         num_beams=generation_config.num_beams,
   (...)
   2274         max_length=generation_config.max_length,
   2275     )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:3254, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3251 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3253 if is_prefill:
-> 3254     outputs = self(**model_inputs, return_dict=True)
   3255     is_prefill = False
   3256 else:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/conda/lib/python3.10/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
    525     labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
    527 causal_mask = self._update_causal_mask(
    528     attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
    529 )
--> 530 outputs = self.language_model(
    531     attention_mask=causal_mask,
    532     position_ids=position_ids,
    533     past_key_values=past_key_values,
    534     inputs_embeds=inputs_embeds,
    535     use_cache=use_cache,
    536     output_attentions=output_attentions,
    537     output_hidden_states=output_hidden_states,
    538     return_dict=return_dict,
    539     cache_position=cache_position,
    540     num_logits_to_keep=num_logits_to_keep,
    541 )
    543 logits = outputs.logits
    544 loss = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
    840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
    843     input_ids=input_ids,
    844     attention_mask=attention_mask,
    845     position_ids=position_ids,
    846     past_key_values=past_key_values,
    847     inputs_embeds=inputs_embeds,
    848     use_cache=use_cache,
    849     output_attentions=output_attentions,
    850     output_hidden_states=output_hidden_states,
    851     return_dict=return_dict,
    852     cache_position=cache_position,
    853 )
    855 hidden_states = outputs[0]
    856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
    617     layer_outputs = self._gradient_checkpointing_func(
    618         decoder_layer.__call__,
    619         hidden_states,
   (...)
    626         cache_position,
    627     )
    628 else:
--> 629     layer_outputs = decoder_layer(
    630         hidden_states,
    631         position_embeddings=position_embeddings,
    632         attention_mask=causal_mask,
    633         position_ids=position_ids,
    634         past_key_value=past_key_values,
    635         output_attentions=output_attentions,
    636         use_cache=use_cache,
    637         cache_position=cache_position,
    638         **flash_attn_kwargs,
    639     )
    641 hidden_states = layer_outputs[0]
    643 if output_attentions:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    296 hidden_states = self.input_layernorm(hidden_states)
    298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
    300     hidden_states=hidden_states,
    301     position_embeddings=position_embeddings,
    302     attention_mask=attention_mask,
    303     position_ids=position_ids,
    304     past_key_value=past_key_value,
    305     output_attentions=output_attentions,
    306     use_cache=use_cache,
    307     cache_position=cache_position,
    308 )
    309 hidden_states = self.post_attention_layernorm(hidden_states)
    310 hidden_states = residual + hidden_states

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
    221 if past_key_value is not None:
    222     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    223     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    226 attention_interface: Callable = eager_attention_forward
    227 if self.config._attn_implementation != "eager":

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1714 else:
   1715     update_fn = self._static_update
-> 1717 return update_fn(
   1718     cache_position,
   1719     layer_idx,
   1720     key_states,
   1721     value_states,
   1722     k_out,
   1723     v_out,
   1724     k_out.shape[2],
   1725 )

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
   1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694     k_out[:, :, cache_position] = key_states
   1695     v_out[:, :, cache_position] = value_states
   1697     self.key_cache[layer_idx] = k_out

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
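
The last frame of the traceback is an indexed write into a preallocated cache tensor (`k_out[:, :, cache_position] = key_states`). As a rough illustration only, and not a diagnosis of the actual bug, the sketch below runs the same kind of write on CPU, where an out-of-range index raises an ordinary Python exception instead of a device-side assert. The tensor sizes and the `static_update` helper are hypothetical stand-ins, not the transformers implementation.

```python
import torch


def static_update(k_out, key_states, cache_position):
    """Mimic the indexed write seen in the traceback (hypothetical helper)."""
    k_out[:, :, cache_position] = key_states
    return k_out


# Hypothetical sizes: (batch, heads, max_cache_len, head_dim).
max_cache_len = 8
k_out = torch.zeros(1, 2, max_cache_len, 4)

# In-bounds write: 6 positions fit into a cache of length 8.
static_update(k_out, torch.ones(1, 2, 6, 4), torch.arange(6))

# Out-of-bounds write: positions 8 and 9 do not exist in the cache.
try:
    static_update(k_out, torch.ones(1, 2, 10, 4), torch.arange(10))
    overflow_error = None
except (IndexError, RuntimeError) as exc:
    overflow_error = exc

print(type(overflow_error).__name__, overflow_error)
```

On GPU the same condition is reported asynchronously, so setting the environment variable `CUDA_LAUNCH_BLOCKING=1` before launching Python may give a more accurate stack trace.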
Google org

Hi, the model card code for google/paligemma2-10b-pt-896 runs fine for me: I could not reproduce the error on a 4x NVIDIA T4 setup (transformers 4.47.1, torch 2.5.1+cu121). Please try again and, if the issue persists, share more details such as your torch version and environment setup. The model output and execution are shown in the screenshot below.
Screenshot 2025-01-13 at 10.55.00 PM.png

Thanks for your quick reply. The code is exactly the same and the problem still persists.

Environment

  • torch: 2.5.1+cu124
  • transformers: 4.48.0
  • GPU: NVIDIA A100-PCIE-40GB
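
For anyone comparing setups, a small sketch like the following can collect the same three details in one go. The helper names `package_version` and `collect_environment` are made up for this example; it only relies on `importlib.metadata` and, if torch is installed, `torch.cuda.get_device_name`.

```python
from importlib import metadata


def package_version(name):
    """Return the installed version of a package, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def collect_environment():
    """Gather torch/transformers versions and the first GPU name, if any."""
    info = {
        "torch": package_version("torch"),
        "transformers": package_version("transformers"),
        "gpu": None,
    }
    try:
        import torch  # optional: only needed to query the GPU name

        if torch.cuda.is_available():
            info["gpu"] = torch.cuda.get_device_name(0)
    except ImportError:
        pass
    return info


env = collect_environment()
for key, value in env.items():
    print(f"{key}: {value}")
```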

There was no problem with the 448-resolution model on the same GPU, in the same container, with the same code.

Google org

I can reproduce this with transformers 4.48.0 and will take a closer look. Meanwhile, would it be possible for you to use 4.47.1, which works fine?

@pcuenq Great, it works fine with version 4.47.1.
For now, I'll stay on that version.
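
As a sketch of that workaround, a script could check the installed transformers version up front and warn when it matches the release reported as failing here. Only 4.48.0 (fails) and 4.47.1 (works) come from this thread; the helper names are hypothetical, and other versions are simply treated as unknown.

```python
import re


def parse_version(text):
    """Extract (major, minor, patch) from strings like '4.48.0' or '2.5.1+cu124'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.groups())


# Versions taken from this thread; nothing else has been tested here.
REPORTED_FAILING = parse_version("4.48.0")
REPORTED_WORKING = parse_version("4.47.1")


def version_status(installed):
    """Classify a transformers version as 'failing', 'working', or 'unknown'."""
    parsed = parse_version(installed)
    if parsed == REPORTED_FAILING:
        return "failing"
    if parsed == REPORTED_WORKING:
        return "working"
    return "unknown"


print(version_status("4.48.0"))
print(version_status("4.47.1"))
```

Pinning with `pip install transformers==4.47.1` is the usual way to stay on the working release until a fix ships.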

ryj0902 changed discussion status to closed
Google org

For reference, the fix is going to be part of https://github.com/huggingface/transformers/pull/35681
