Gemma 2 returns NaN when using default attn (sdpa) with padding
System Info
Python 3.10, Transformers 4.43.3, Linux (Colab notebook)
Who can help?
@ArthurZucker
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
The default Gemma 2 2B attention implementation (sdpa) produces NaN when the batch contains padding. A simple demo is below (also reproduced in this Colab notebook):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
inputs = tokenizer(["Hello I am a couch", "cats"], return_tensors="pt", padding=True).to('cuda')
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    print(outputs.logits)
This returns the following:
tensor([[[-24.3121, -8.7513, -6.9736, ..., -18.3960, -17.4268, -24.3171],
[-16.8873, -4.7767, 5.8828, ..., -9.4981, -9.3307, -16.7723],
[-18.3313, 1.3191, -4.6598, ..., -2.4244, 1.6774, -18.2153],
[-18.9110, -5.8708, -11.7827, ..., -5.6606, -4.2607, -18.8535],
[-20.1359, -8.4194, -15.1834, ..., -13.0231, -11.8288, -19.9716],
[-16.8807, 5.8885, 0.1881, ..., -3.7045, -6.0659, -16.8421]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]]],
device='cuda:0')
This can be fixed by setting attn_implementation to anything other than sdpa.
Expected behavior
Using padding should not result in NaN for normal inputs to Gemma 2 2B.
Hi @chanind, thanks for reporting the issue!
This is indeed a problem with scaled_dot_product_attention in PyTorch:
- https://github.com/pytorch/pytorch/issues/103963
The NaN comes from how softmax is computed over fully masked rows of the attention mask. I hope it will be fixed in a future version of PyTorch; here is a related PR:
- https://github.com/pytorch/pytorch/pull/131060
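A minimal pure-Python sketch of the failure mode described above, assuming the masked positions end up as -inf: a numerically stable softmax subtracts the row maximum, and when every entry in a row is -inf that subtraction is -inf - (-inf) = NaN, which then spreads to every output. Python floats are float64, so this only illustrates the arithmetic, not PyTorch's actual sdpa kernels; `stable_softmax` and `apply_additive_mask` are hypothetical helpers.

```python
import math

def stable_softmax(row):
    """Softmax that subtracts the row max before exponentiating."""
    row_max = max(row)
    # If every entry is -inf, row_max is -inf and x - row_max is nan.
    exps = [math.exp(x - row_max) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def apply_additive_mask(scores, keep, mask_value):
    """Add 0 where keep is True and mask_value where it is False."""
    return [s + (0.0 if k else mask_value) for s, k in zip(scores, keep)]

scores = [1.5, -0.3, 2.0]

# A row with at least one visible key behaves normally.
partial = stable_softmax(apply_additive_mask(scores, [True, True, False], -math.inf))
print(partial)  # finite, and the masked last entry is 0.0

# A fully masked row (e.g. a padding query) yields NaN everywhere.
full = stable_softmax(apply_additive_mask(scores, [False, False, False], -math.inf))
print(full)  # [nan, nan, nan]
```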
Also, a similar issue has been reported previously:
- https://github.com/huggingface/transformers/issues/31035
Besides switching to eager/flash_attention_2, you could also try:
- Use the float16 dtype.
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2-2b", device_map="auto", torch_dtype=torch.float16
)
- Modify the attn_mask min value.
As suggested in the issue above, we can make attn_mask use a different min value instead of torch.finfo(dtype).min, for example torch.finfo(dtype).min / 2. To apply this, find min_dtype = torch.finfo(dtype).min in the Gemma modeling file and replace it with torch.finfo(dtype).min / 2.
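The arithmetic behind the min / 2 workaround can be sketched in pure Python, under the assumption (not confirmed in this thread) that the min value gets summed with another large negative value, such as a second additive mask, before the softmax. Python floats are float64, so -sys.float_info.max stands in for torch.finfo(dtype).min here. Summing two such values overflows to -inf and the fully masked row turns into NaN; halving leaves headroom, so the sum stays finite and the row simply becomes uniform. The helper `fully_masked_row` is hypothetical.

```python
import math
import sys

# float64 stand-in for torch.finfo(dtype).min (assumption: same idea in lower precision)
MIN_VALUE = -sys.float_info.max

def stable_softmax(row):
    row_max = max(row)
    exps = [math.exp(x - row_max) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def fully_masked_row(mask_value, n=3):
    """Sum two additive masks (e.g. causal + padding) that both hide every key."""
    causal = [mask_value] * n
    padding = [mask_value] * n
    return [c + p for c, p in zip(causal, padding)]

row_min = fully_masked_row(MIN_VALUE)        # -max + -max overflows to -inf
row_half = fully_masked_row(MIN_VALUE / 2)   # -max/2 + -max/2 == -max, still finite

print(row_min)                   # [-inf, -inf, -inf]
print(stable_softmax(row_min))   # [nan, nan, nan]
print(stable_softmax(row_half))  # each entry is 1/3
```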
Meanwhile, we will try to fix it on our side, thanks!
Beyond that, this is expected, as the sdpa path does not support logit soft-capping (for Gemma 2).
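For context on the soft-capping remark: Gemma 2 squashes logits with a scaled tanh, cap * tanh(x / cap), which keeps them inside (-cap, cap). Because F.scaled_dot_product_attention computes the scores and the softmax in one fused call, there is no point at which to apply that tanh to the attention scores. A minimal pure-Python sketch follows; the cap values 50.0 (attention) and 30.0 (final logits) are what I understand the Gemma2Config defaults attn_logit_softcapping and final_logit_softcapping to be, so treat them as assumptions and check the config you actually load.

```python
import math

# Assumed Gemma 2 defaults; verify against the model's config.
ATTN_LOGIT_SOFTCAP = 50.0
FINAL_LOGIT_SOFTCAP = 30.0

def soft_cap(x, cap):
    """Smoothly squash x into (-cap, cap); near-identity for |x| << cap."""
    return cap * math.tanh(x / cap)

def capped_attention_weights(scores, cap=ATTN_LOGIT_SOFTCAP):
    """Eager-style attention weights: soft-cap the raw scores, then softmax."""
    capped = [soft_cap(s, cap) for s in scores]
    row_max = max(capped)
    exps = [math.exp(c - row_max) for c in capped]
    total = sum(exps)
    return [e / total for e in exps]

print(soft_cap(1.0, ATTN_LOGIT_SOFTCAP))    # close to 1.0
print(soft_cap(500.0, ATTN_LOGIT_SOFTCAP))  # just under 50.0
print(capped_attention_weights([1.0, 200.0, -200.0]))
```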
@qubvel, we do already account for the sdpa bug when creating the mask; see here: https://github.com/huggingface/transformers/blob/c1aa0edb48217f416f4bbe6e3a9db1500284513b/src/transformers/models/llama/modeling_llama.py#L1063-L1072
This should be propagated to Gemma2 (it was missing there for some reason, my bad).
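The mask fix referenced above targets query rows in which every key is masked (for example the first rows of a left-padded sequence) and lets them attend to all positions instead, so the softmax never sees an all-masked row; the outputs at those padding positions are not used anyway. Below is a pure-Python sketch of that idea on a single 2D mask. `build_mask` and `unmask_fully_masked_rows` are hypothetical helpers written for illustration, not the transformers implementation (which works on 4D tensors), and -sys.float_info.max stands in for torch.finfo(dtype).min.

```python
import sys

MIN_VALUE = -sys.float_info.max  # float64 stand-in for torch.finfo(dtype).min

def build_mask(attention_mask):
    """Additive causal + padding mask for one sequence.

    attention_mask: list of 0/1 flags (0 = padding), as returned by a tokenizer.
    Entry [q][k] is 0.0 if query q may attend to key k, else MIN_VALUE.
    """
    n = len(attention_mask)
    return [
        [0.0 if (k <= q and attention_mask[k] == 1) else MIN_VALUE for k in range(n)]
        for q in range(n)
    ]

def unmask_fully_masked_rows(mask):
    """Hypothetical helper: rows where every entry is masked attend to everything."""
    return [
        [0.0] * len(row) if all(v == MIN_VALUE for v in row) else row
        for row in mask
    ]

mask = build_mask([0, 0, 1, 1])  # two left-padding tokens
fixed = unmask_fully_masked_rows(mask)
for row in fixed:
    print(row)
```

Here the two padding queries (rows 0 and 1) would otherwise see only masked keys; after the fix they attend uniformly, while the rows for real tokens are left untouched.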
Related to #31303
@ArthurZucker thanks for the updated info!
Hi, I ran into a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is NaN:
What's wrong? I use the same code to fine-tune Llama3-8B and it works well.
These are my settings:
Same issue here when running code that hooks the model's activations. Using float16 made it work.
Hey! Make sure you are using eager or flash_attention_2 not sdpa!
Hi, I have the same issue. How did you solve it? 😊
Hi, I just use eager instead of sdpa, like this:
model = AutoModelForCausalLM.from_pretrained(
    args.prune_model_path,
    trust_remote_code=True,
    device_map=device_map,
    attn_implementation="eager",
)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.