SDPA and FA2 produce different outputs
System Info
Hi,
We noticed a new failure in the CI/CD of kvpress which is related to differences between SDPA and FA2.
Here is my system info:
- transformers version: 4.57.3
- Platform: Linux-6.1.123+-x86_64-with-glibc2.39
- Python version: 3.12.3
- Huggingface_hub version: 0.36.0
- Safetensors version: 0.7.0
- Accelerate version: 1.12.0
- PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
@ArthurZucker @Cyrilvallez @Rocketknight1
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
The following code produces different outputs for SDPA and FA2, even though do_sample is set to False:
from transformers import pipeline

model_name = "meta-llama/Llama-3.2-1B-Instruct"
prompt = "Hello, how are you?"

# One pipeline per attention implementation
pipe_fa2 = pipeline("text-generation", model=model_name, device_map="auto", dtype="auto", model_kwargs={"attn_implementation": "flash_attention_2"})
pipe_sdpa = pipeline("text-generation", model=model_name, device_map="auto", dtype="auto", model_kwargs={"attn_implementation": "sdpa"})

for _ in range(3):
    print(pipe_fa2(prompt, max_new_tokens=15, do_sample=False)[0]["generated_text"])
    print(pipe_sdpa(prompt, max_new_tokens=15, do_sample=False)[0]["generated_text"])
Hello, how are you? I'm excited to be here today to talk about a very important topic that
Hello, how are you? I'm excited to be here today to talk to you about something that I
Hello, how are you? I'm excited to be here today to talk about a very important topic that
Hello, how are you? I'm excited to be here today to talk to you about something that I
Hello, how are you? I'm excited to be here today to talk about a very important topic that
Hello, how are you? I'm excited to be here today to talk to you about something that I
Expected behavior
Is this behavior expected?
Hi @SimJeg, there are small numerical differences between different implementations of attention, particularly in bfloat16 precision. This is because floating-point computations are rounded to limited precision, so performing mathematically equivalent operations in a slightly different order can yield slightly different outputs.
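To make that concrete, here is a quick sketch, not of the actual attention kernels, just of the rounding effect: the same matrix product computed in one shot versus in two chunks (so that partial results get rounded to bfloat16 before they are added) will typically disagree in the last bits.

import torch

torch.manual_seed(0)
# Toy attention-like shapes: 8 queries, 128 keys/values, head dim 64.
scores = torch.randn(8, 128, dtype=torch.bfloat16).softmax(dim=-1)
v = torch.randn(128, 64, dtype=torch.bfloat16)

# Same mathematical result, computed in one shot...
out_full = scores @ v
# ...versus in two chunks, which changes the order of the additions
# and rounds the partial results to bfloat16 before summing them.
out_chunked = scores[:, :64] @ v[:64] + scores[:, 64:] @ v[64:]

print((out_full - out_chunked).abs().max())  # typically small but nonzero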
Although their outputs will be extremely similar, switching between two attention implementations can slightly change the top logit in some cases, which changes the generated output even when do_sample=False. I suspect that if you run that sequence through the model and check the output logits, the logits for "about" and "to" at the point of divergence will be very close, and a small numerical difference is enough to change which one is larger. This isn't a bug we can fix, and both are valid outputs!
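For reference, a rough sketch of that check; the prefix and the candidate tokens "about" / "to" are read off the outputs above, so adjust them to your exact tokenization:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Shared prefix up to the point where the two generations diverge.
prefix = "Hello, how are you? I'm excited to be here today to talk"
inputs = tokenizer(prefix, return_tensors="pt")
about_id = tokenizer.encode(" about", add_special_tokens=False)[0]
to_id = tokenizer.encode(" to", add_special_tokens=False)[0]

for attn in ("sdpa", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype="auto", device_map="auto", attn_implementation=attn
    )
    with torch.no_grad():
        next_logits = model(**inputs.to(model.device)).logits[0, -1]
    # Print the next-token logits of the two competing continuations
    print(attn, next_logits[about_id].item(), next_logits[to_id].item())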
cc @vasqu here as well for FA2, seems like we have a regression if everything was 1:1 before! Not the first issue about this recently I think, if you can double-check!
@Rocketknight1 @Cyrilvallez thanks for your feedback. There is indeed a regression compared to previous versions (the test was passing before in our CI/CD), but I'll let you decide whether to close this issue. On our side, we now simply skip that test.
Ok so it's important to note that the environment is: H100, torch=2.9.1, flash attention 2 compiled from source.
I went the extra mile and ran the script below, exchanging dtype <-> torch_dtype:
from transformers import pipeline

model_name = "meta-llama/Llama-3.2-1B-Instruct"
prompt = "Hello, how are you?"

pipe_sdpa = pipeline("text-generation", model=model_name, device_map="auto", torch_dtype="auto", model_kwargs={"attn_implementation": "sdpa"})
pipe_fa2 = pipeline("text-generation", model=model_name, device_map="auto", torch_dtype="auto", model_kwargs={"attn_implementation": "flash_attention_2"})

sdpa_output = pipe_sdpa(prompt, max_new_tokens=15, do_sample=False)[0]["generated_text"]
fa2_output = pipe_fa2(prompt, max_new_tokens=15, do_sample=False)[0]["generated_text"]
print(sdpa_output)
print(fa2_output)
print(sdpa_output == fa2_output)
I did this from main and then for each minor version back to 4.47.0, and each time SDPA and FA2 produced different outputs. So I'm not sure in which previous version your test passed, and in which environment, especially which torch version, which might be the culprit here.
So yeah, unless I can reproduce a version that produces the same output, it's hard to tell. And while libraries try to keep numbers consistent, small fluctuations between versions are not unexpected, and in LLMs they can easily lead to slightly different outputs.
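If it helps for the kvpress CI, a less brittle check could compare logits with a tolerance instead of requiring identical generated text; a rough sketch (the threshold is illustrative, not a value validated here):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, how are you?", return_tensors="pt")

logits = {}
for attn in ("sdpa", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype="auto", device_map="auto", attn_implementation=attn
    )
    with torch.no_grad():
        logits[attn] = model(**inputs.to(model.device)).logits.float().cpu()

# Compare prefill logits numerically rather than comparing generated strings.
max_diff = (logits["sdpa"] - logits["flash_attention_2"]).abs().max().item()
print(f"max abs logit diff: {max_diff:.4f}")
assert max_diff < 0.5  # illustrative threshold, tune for your model/setup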
Thanks @vasqu for looking into this in closer detail. If the fluctuations are expected, I'll close this issue and remove our test from kvpress.