Fix llama model sdpa attention forward function masking bug when output_attentions=True
What does this PR do?
Very simple fix to a nasty issue I have recently encountered. Due to its simplicity, I opened a PR directly without raising an issue first to avoid redundancy. Please, let me know if I should also raise an issue, and I'll do that right away.
Description
When output_attentions is True, the sdpa implementation's forward method falls back to the eager implementation's forward method. However, a None mask is still passed whenever 'AttentionMaskConverter._ignore_causal_mask_sdpa' returns True, which happens when the input is unmasked, because sdpa would normally defer causal masking to the PyTorch sdpa kernel. As a result, the model runs the eager implementation with no causal attention mask if the original input is unmasked (e.g., a single input sequence is encoded, or all encoded input sequences have the same length) and output_attentions=True.
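The mask-selection problem described above can be sketched in isolation. This is a simplified, pure-Python stand-in and not the actual transformers code: `build_causal_mask`, `choose_causal_mask`, `can_skip_mask`, and the `fixed` flag are hypothetical names introduced for illustration. It assumes that returning None means "let the sdpa kernel apply causality itself", which is only safe when the sdpa kernel is the one that actually runs.

```python
def build_causal_mask(seq_len):
    """Lower-triangular boolean mask: position i may attend to positions j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def choose_causal_mask(attention_mask, attn_implementation, output_attentions, fixed=True):
    """Hypothetical stand-in for the model's mask selection.

    attention_mask: list of 0/1 flags for one sequence (1 = real token).
    Returns None when the explicit mask can be skipped, else a 2D boolean mask.
    """
    seq_len = len(attention_mask)
    input_is_unmasked = all(flag == 1 for flag in attention_mask)

    can_skip_mask = attn_implementation == "sdpa" and input_is_unmasked
    if fixed:
        # Assumed fix: with output_attentions=True the sdpa layer falls back to
        # the eager forward, which needs an explicit causal mask.
        can_skip_mask = can_skip_mask and not output_attentions

    if can_skip_mask:
        return None

    causal = build_causal_mask(seq_len)
    # Also hide padded key positions.
    return [
        [causal[i][j] and attention_mask[j] == 1 for j in range(seq_len)]
        for i in range(seq_len)
    ]


unmasked = [1, 1, 1]
# Buggy behaviour: the eager fallback receives no mask at all.
print(choose_causal_mask(unmasked, "sdpa", output_attentions=True, fixed=False))
# With the extra condition, an explicit causal mask is built.
print(choose_causal_mask(unmasked, "sdpa", output_attentions=True, fixed=True))
```

The only design point here is that the decision to skip the mask depends on which attention code path will run, not just on whether the input is unmasked.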
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Tagging @ArthurZucker and @younesbelkada
A minimal example of this erroneous behavior can be reproduced via:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='cuda',
    torch_dtype=torch.bfloat16,
)
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(
    ["Today is the day I went to the store and ..."],
    return_tensors="pt",
).to('cuda')
expanded_batch_size = 1
outputs = model.generate(
    input_ids=inputs['input_ids'].expand(expanded_batch_size, -1),
    attention_mask=inputs['attention_mask'].expand(expanded_batch_size, -1),
    do_sample=False,
    max_new_tokens=5,
    return_dict_in_generate=True,
)
input_length = inputs.input_ids.shape[1]
sequences = outputs.sequences
for sequence in sequences:
    decoded_sequence = tokenizer.decode(sequence)
    print(decoded_sequence)
# separator
print('-'*20)
outputs = model.generate(
    input_ids=inputs['input_ids'].expand(expanded_batch_size, -1),
    attention_mask=inputs['attention_mask'].expand(expanded_batch_size, -1),
    do_sample=False,
    max_new_tokens=5,
    return_dict_in_generate=True,
    output_attentions=True,  # the only change from the call above
)
input_length = inputs.input_ids.shape[1]
sequences = outputs.sequences
# garbage generated outputs since no causal masking is applied
for sequence in sequences:
    decoded_sequence = tokenizer.decode(sequence)
    print(decoded_sequence)
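Why a missing causal mask produces garbage can be seen on a toy example that needs no model download. The snippet below is a minimal pure-Python sketch of eager-style attention weights; `softmax` and `eager_attention_weights` are hypothetical helpers written for this illustration, and the vectors are made up. It only shows that, without a causal mask, the first token puts weight on tokens that come after it.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def eager_attention_weights(queries, keys, causal):
    """Attention weights for one head; queries and keys are lists of vectors."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    weights = []
    for i, q in enumerate(queries):
        scores = []
        for j, k in enumerate(keys):
            score = scale * sum(a * b for a, b in zip(q, k))
            if causal and j > i:
                # A causal mask hides future positions from position i.
                score = float("-inf")
            scores.append(score)
        weights.append(softmax(scores))
    return weights


# Made-up 2-dimensional states for a 3-token sequence.
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

with_mask = eager_attention_weights(states, states, causal=True)
without_mask = eager_attention_weights(states, states, causal=False)

print("first token, causal mask:", with_mask[0])
print("first token, no mask:    ", without_mask[0])
```

With the mask, the first token can only attend to itself; without it, some of its weight goes to later tokens, which a causally trained model never saw during training.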
Great catch.
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) on line 1127 needs to be ignored as well.
- We need to add your small example script as a test! 🤗
@ArthurZucker Thanks for reviewing my pull request and all your work in maintaining this awesome repo! :) Regarding your comments:
- Done.
- Let me know if you would like me to make a small testing script for this bug myself! (i.e., check that generated outputs with the 'eager' implementation match the generated outputs with output_attentions=True, although inherent stochasticity in the GPU kernels might make it difficult to always get 100% consistent results).
P.S. Some CircleCI tests seem to be failing on the main branch, and they are now failing here as well after I merged.
For 2., the test is already implemented, but I don't think it tests output_attentions=True. It's probably a matter of adding a parameterized case. See this file here (and the generate tests): https://github.com/huggingface/transformers/blob/main/tests/test_modeling_common.py#L3590.
Potentially add output_attentions to make sure sdpa with output_attentions matches eager, with or without it (which it is supposed to!)
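The shape of such a check can be sketched as follows. This is not the test in tests/test_modeling_common.py: `check_sdpa_matches_eager`, `make_fake_model`, and `run_model` are hypothetical names, and the "model" is a fake that returns fixed numbers so the sketch runs without transformers or a GPU. The tolerance is also an assumption, included because kernel-level differences rule out exact equality.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length lists."""
    return max(abs(x - y) for x, y in zip(a, b))


def check_sdpa_matches_eager(run_model, atol=1e-4):
    """Return the (output_attentions, diff) settings where sdpa deviates from eager.

    run_model(attn_implementation, output_attentions) is a hypothetical callable
    returning the logits as a flat list of floats.
    """
    reference = run_model("eager", False)
    failures = []
    for output_attentions in (False, True):
        candidate = run_model("sdpa", output_attentions)
        diff = max_abs_diff(reference, candidate)
        if diff > atol:
            failures.append((output_attentions, diff))
    return failures


def make_fake_model(buggy):
    """Build a fake run_model; buggy=True mimics the unmasked eager fallback."""
    def run_model(attn_implementation, output_attentions):
        logits = [0.1, 0.7, 0.2]
        if attn_implementation == "sdpa":
            # Small deviation standing in for kernel-level numerical noise.
            logits = [value + 1e-6 for value in logits]
        if buggy and attn_implementation == "sdpa" and output_attentions:
            # Stands in for the wrong outputs produced without a causal mask.
            logits = [0.4, 0.3, 0.3]
        return logits
    return run_model


print(check_sdpa_matches_eager(make_fake_model(buggy=False)))
print(check_sdpa_matches_eager(make_fake_model(buggy=True)))
```

Looping over both values of output_attentions is the point: the bug only shows up in the True case, which a check that uses the default alone would miss.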
Feel free to rebase; it might be fixed on main, or it might just be flaky.
Just did :)
@ArthurZucker Let me know if you think this fix is ready for merging, or if you'd like to add the tests to the same PR!
Would be nice to just add the test in this PR 😉
Alright - I made the addition of output_attentions=True to the sdpa equivalence test, as you suggested ;) (Black code re-formatting seems to have messed up the diff, but the changes are minimal...)
@ArthurZucker - Let me know if there are any outstanding issues or if there is something else missing before merging ^^
(Merging once the CIs are all green!)
@ArthurZucker thanks for your suggestions! I also propagated the same changes to the new jetmoe model. All default checks are now passing ^^
Thanks for the fix
@Aladoro thank you for detecting the issue and making transformers better for all of us 💛