Running into a tensor size error

#33
by AaronVogler - opened

When I run the model on CPU, I get the following error:

Fetching 50 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 15980.74it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 200.15it/s]
Device set to use cpu
Traceback (most recent call last):
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/inf3rnus/.vscode-server/extensions/ms-python.debugpy-2025.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
    cli.main()
  File "/home/inf3rnus/.vscode-server/extensions/ms-python.debugpy-2025.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
    run()
  File "/home/inf3rnus/.vscode-server/extensions/ms-python.debugpy-2025.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/inf3rnus/.vscode-server/extensions/ms-python.debugpy-2025.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
    return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
  File "/home/inf3rnus/.vscode-server/extensions/ms-python.debugpy-2025.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
    _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
  File "/home/inf3rnus/.vscode-server/extensions/ms-python.debugpy-2025.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
    exec(code, run_globals)
  File "/home/inf3rnus/test_llama4.py", line 13, in <module>
    output = pipe("Roses are red,", max_new_tokens=200)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/pipelines/text_generation.py", line 287, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/pipelines/base.py", line 1379, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/pipelines/base.py", line 1386, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/pipelines/base.py", line 1286, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/pipelines/text_generation.py", line 385, in _forward
    output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/generation/utils.py", line 2463, in generate
    result = self._sample(
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/generation/utils.py", line 3429, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/models/llama4/modeling_llama4.py", line 1022, in forward
    outputs = self.model(
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/models/llama4/modeling_llama4.py", line 701, in forward
    layer_outputs = decoder_layer(
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/models/llama4/modeling_llama4.py", line 435, in forward
    attention_states, self_attn_weights = self.self_attn(
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/models/llama4/modeling_llama4.py", line 379, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/inf3rnus/miniconda3/envs/llama4/lib/python3.9/site-packages/transformers/integrations/sdpa_attention.py", line 54, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The size of tensor a (8192) must match the size of tensor b (205) at non-singleton dimension 3

Any idea how to fix this? This is using the provided example:

from transformers import pipeline
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E"

pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

output = pipe("Roses are red,", max_new_tokens=200)

# Print the generated text (the pipeline returns a list of dicts).
print(output[0]["generated_text"])

FYI, I first thought this was because I was using torch.bfloat16, which I believed is only supported on GPUs. However, I still get the same size error when running only on the CPU.

I think this is a regression introduced in transformers 4.51.1. I also could not make it work with images (in 4.51.0 it worked with attn_implementation="eager"). It seems to be fixed in 4.51.2.
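
If it helps anyone landing here, this is a minimal sketch for checking whether the installed transformers release is the one reported above as affected. It assumes only what this thread says (4.51.1 broken, 4.51.2 apparently fine). The helper names parse_version, is_suspect and installed_transformers_version are made up for illustration and are not part of transformers.

```python
from importlib import metadata

# Assumption taken from this thread only: 4.51.1 is the release reported
# as broken, and 4.51.2 is reported to work again.
SUSPECT_VERSION = (4, 51, 1)


def parse_version(version: str) -> tuple:
    """Turn '4.51.1' or '4.51.2.dev0' into (major, minor, patch) integers."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad short versions such as '4.51'
    return tuple(parts)


def is_suspect(version: str) -> bool:
    """True only for the release this thread reports as affected."""
    return parse_version(version) == SUSPECT_VERSION


def installed_transformers_version():
    """Return the installed transformers version string, or None if absent."""
    try:
        return metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_transformers_version()
    if found is None:
        print("transformers is not installed in this environment")
    elif is_suspect(found):
        print(f"transformers {found}: reported as affected in this thread")
    else:
        print(f"transformers {found}: not the release reported as affected")
```

If the script reports the affected release, upgrading with pip install "transformers>=4.51.2" is the obvious thing to try, going by the report above that 4.51.2 works.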
