
"RuntimeError: p.attn_bias_ptr is not correctly aligned" when using ``` in a prompt with images

Open FredrikNoren opened this issue 9 months ago • 19 comments

I'm using google/gemma-3-4b-it through huggingface transformers to process videos, and the following prompt works:

    [
        {"role": "user", "content": [
            *[{"type": "image", "image": img} for img in images[0:20]],
            {"type": "text", "text": "Describe this video in detail"},
        ]},
    ]

But if I change it to the following, I get a crash (I was originally trying to add a TypeScript JSON description; this is a simplified example):

    [
        {"role": "user", "content": [
            *[{"type": "image", "image": img} for img in images[0:20]],
            {"type": "text", "text": "Describe this video \n```\nin\n```\n detail"},
        ]},
    ]
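
For reference, the messages are fed through the standard HF multimodal pipeline, roughly like this (a simplified sketch, not the exact gemma3.py code):

    from transformers import AutoProcessor, Gemma3ForConditionalGeneration
    import torch

    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)

    # `messages` is one of the lists shown above
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=512)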

The crash I'm getting is:

  File "/tmp/ray/session_2025-03-12_08-22-53_750470_309/runtime_resources/working_dir_files/_ray_pkg_684ab80be417364b/agent/policies/gemma3.py", line 54, in run_messages
    outputs = self.model.generate(**inputs, max_new_tokens=512)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2250, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 3241, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1352, in forward
    outputs = self.language_model(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 976, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 754, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 443, in forward
    hidden_states, self_attn_weights = self.self_attn(
                                       ^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 365, in forward
    attn_output, attn_weights = attention_interface(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/integrations/sdpa_attention.py", line 54, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: p.attn_bias_ptr is not correctly aligned

FredrikNoren avatar Mar 13 '25 09:03 FredrikNoren

This is on an A100 GPU, and I'm using torch 2.6.0.

I'm also getting this error randomly with different image inputs: I'm processing videos in a loop and it fails for some of them, even though the rest of the prompt is identical.

FredrikNoren avatar Mar 13 '25 10:03 FredrikNoren

Same problem here. I'm using multi-image, multi-turn generation with Hugging Face and would appreciate any help! All model sizes (4b, 12b, 27b) have this problem, but it is random and depends on the image.

BiEchi avatar Mar 15 '25 05:03 BiEchi

@FredrikNoren did you try using the Gemma repo instead of HF? Does it also happen with the original repo, or is it only a problem with HF?

BiEchi avatar Mar 15 '25 15:03 BiEchi

I just tried this:

        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            ...,
            attn_implementation="eager",
        ).eval()

Maybe you can also try a manual attention implementation (eager) first.

BiEchi avatar Mar 15 '25 21:03 BiEchi

Getting the same issue without any image inputs after fine-tuning with Unsloth (again, somewhat randomly and unpredictably). Has anyone got this to work?

amackenzie1 avatar Mar 15 '25 22:03 amackenzie1

Same here.

kennymckormick avatar Mar 17 '25 07:03 kennymckormick

You can try adding the option do_pan_and_scan=True in processor.apply_chat_template.
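
Something like this (a sketch; whether apply_chat_template forwards the kwarg to the image processor may depend on your transformers version):

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        do_pan_and_scan=True,  # enables Gemma3's pan-and-scan image cropping
    ).to(model.device)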

alexchenfeng avatar Mar 17 '25 07:03 alexchenfeng

I get the same error randomly for both gemma-3-4b-it and gemma-3-12b-it, even without images. I tried the suggested do_pan_and_scan=True with no effect, while attn_implementation="eager" possibly fixes the issue but soon causes CUDA out of memory (on an RTX A6000 with 48 GB).

alex-me avatar Mar 17 '25 18:03 alex-me

@BiEchi I've only tried HF

FredrikNoren avatar Mar 18 '25 08:03 FredrikNoren

@alex-me I also tried attn_implementation="eager" when loading the model, but I still get this error on a specific input. I am using an A100.

ServientShao avatar Mar 18 '25 20:03 ServientShao

Hey, I found a workaround:

    import torch

    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

but this may slow down performance.
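
If you don't want to flip these flags globally, the same restriction can be scoped to a single call with the SDPA context manager (a sketch, assuming torch >= 2.3):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Restrict scaled_dot_product_attention to the math backend for this call only
    with sdpa_kernel(SDPBackend.MATH):
        outputs = model.generate(**inputs, max_new_tokens=512)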

ServientShao avatar Mar 18 '25 21:03 ServientShao

Add pad_to_multiple_of=8; most tokenizers (and Gemma's processor) support it. That padding removes the misalignment, so FlashAttention can run without crashing, and it uses far less memory than the math fallback.

    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding="longest",
        pad_to_multiple_of=8,  # <-- ensures seq_len % 8 == 0
    ).to(model.device)

Also set processor.tokenizer.padding_side = "left".
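
For the multimodal case, the same idea should work through the processor (a sketch, assuming the Gemma3 processor forwards the padding kwargs to its tokenizer; prompt_text and images stand in for your rendered prompt and image list):

    inputs = processor(
        text=prompt_text,
        images=images,
        return_tensors="pt",
        padding="longest",
        pad_to_multiple_of=8,  # same trick: keep seq_len a multiple of 8
    ).to(model.device)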

hihiruby avatar Mar 19 '25 09:03 hihiruby

Just added this; it's solved.

fuchao01 avatar Mar 20 '25 06:03 fuchao01

I solved it by changing the prompt. The following line (in Python) crashed my gemma-3-12b-it, because % and the following d were recognized as a control sequence by the Python linter:

  • "10% discount" → "ten percent discount"

shigabeev avatar Mar 26 '25 03:03 shigabeev

Could you please confirm whether this issue is resolved for you by the above comment? Please feel free to close the issue if it is resolved.

Thank you.

Gopi-Uppari avatar Apr 25 '25 01:04 Gopi-Uppari

I still run into this issue even with the tokenizer using pad_to_multiple_of=8. I could only circumvent it with eager attention instead of flash-attn.

Banthafutter avatar Apr 26 '25 21:04 Banthafutter

Is there a definitive solution to this problem yet?

Or could someone at least explain what is causing this error?

Saisriram01 avatar Jun 26 '25 05:06 Saisriram01

I solved it by upgrading transformers: pip install transformers==4.53.2

nazarenodefrancesc avatar Jul 16 '25 08:07 nazarenodefrancesc

Hi @FredrikNoren ,

Could you please confirm whether the above issue is resolved? Please let me know if you require any further assistance. Thanks for your continued interest and patience.

Thanks.