flash-attention
Triton version is faster in both forward and backward when head dim is 64 but slower in both when head dim is 128
Thanks for your amazing work!
In the current version of the Triton implementation (https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py), the Triton version is faster in both forward and backward than the CUTLASS implementation of FlashAttention-2 when using the default config in that tutorial (head dim 64). However, when the head dim is changed to 128, the Triton version performs worse in both forward and backward.
Do you know what might be causing this? And in the CUTLASS implementation, are there any specific optimizations that prevent significant performance degradation when the head dimension is 128?
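For anyone trying to reproduce this, here is a minimal sketch (not the tutorial's harness; the seqlen of 4096 and the fp16 (batch, seqlen, nheads, head_dim) layout are assumptions) that times FlashAttention-2's flash_attn_func forward and backward at both head dims. The tutorial's own attention op can be timed the same way with triton.testing.do_bench.

```python
import torch
import triton
from flash_attn import flash_attn_func

def bench_fa2(head_dim, batch=4, nheads=32, seqlen=4096, causal=True):
    # (batch, seqlen, nheads, head_dim) fp16 layout expected by flash_attn_func
    q, k, v = (torch.randn(batch, seqlen, nheads, head_dim, device="cuda",
                           dtype=torch.float16, requires_grad=True)
               for _ in range(3))
    fwd_ms = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=causal))
    out = flash_attn_func(q, k, v, causal=causal)
    dout = torch.randn_like(out)
    bwd_ms = triton.testing.do_bench(lambda: out.backward(dout, retain_graph=True))
    return fwd_ms, bwd_ms

for hd in (64, 128):
    fwd_ms, bwd_ms = bench_fa2(hd)
    print(f"head_dim={hd}: fwd {fwd_ms:.3f} ms, bwd {bwd_ms:.3f} ms")
```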
On A100 or H100? If H100 then the Triton version uses new instructions on H100 but FA2 doesn't. You should try FA3 if you're on H100.
Thanks for your reply! The testing environment is:
- GPU: A100
- flash-attn: 2.6.3
- triton: 3.1.0
What speed (TFLOPS) do you get?
These are the results from https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py with:
- BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
- BATCH, N_HEADS, HEAD_DIM = 4, 32, 128
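To make the numbers comparable, this is roughly the accounting the tutorial uses to turn a measured runtime into TFLOPS for these configs (a sketch; the tutorial sweeps N_CTX, so the 4096 in the example call is an assumption):

```python
def attn_tflops(ms, batch, nheads, n_ctx, head_dim, causal=True, mode="fwd"):
    # Two matmuls (QK^T and PV), each 2*B*H*N^2*D FLOPs; causal halves the work;
    # the tutorial counts backward as roughly 2.5x the forward (recompute included).
    flops_per_matmul = 2.0 * batch * nheads * n_ctx * n_ctx * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5
    return total_flops * 1e-12 / (ms * 1e-3)

# e.g. a 5 ms causal forward at BATCH, N_HEADS, HEAD_DIM = 4, 32, 128 and N_CTX = 4096
print(attn_tflops(5.0, 4, 32, 4096, 128))
```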
You can try FA3 too, which runs on A100 now. Btw, the Triton bwd does not support causal=False; when you call it with causal=False it still runs with causal=True. You can check the output.
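A quick way to see this is to diff the Triton kernel's non-causal gradients against a PyTorch reference. This is a sketch: attention(q, k, v, causal, sm_scale) is the tutorial's entry point (its argument list differs slightly across Triton versions), the import name is hypothetical, and the (BATCH, H, N_CTX, HEAD_DIM) fp16 layout follows the tutorial.

```python
import torch
import torch.nn.functional as F
from fused_attention_tutorial import attention  # hypothetical: code copied from 06-fused-attention.py

B, H, N, D = 1, 4, 1024, 64
sm_scale = D ** -0.5  # matches SDPA's default 1/sqrt(head_dim) scaling
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))
dout = torch.randn_like(q)

out_tri = attention(q, k, v, False, sm_scale)   # request non-causal
out_tri.backward(dout)
dq_tri, q.grad = q.grad.clone(), None

out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)
out_ref.backward(dout)

print("fwd max diff:", (out_tri - out_ref).abs().max().item())
print("dq  max diff:", (dq_tri - q.grad).abs().max().item())  # large if bwd silently ran causal
```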
Thanks for pointing out my mistake!
Do you know what causes Triton's slowdown at head_dim=128? In contrast, the official FA2 performs well - is this due to targeted optimizations?
Most likely register spilling if I have to guess. You can try smaller block sizes to see if that helps.
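One way to check the spilling guess from Python is sketched below, under the assumption that in recent Triton versions the handle returned by a kernel launch exposes n_regs and n_spills (attribute availability may differ in your release). The same pattern can be applied to the tutorial's _attn_fwd/_attn_bwd launches, or you can look at registers per thread and local-memory traffic in ncu instead.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)
# Launching returns the compiled-kernel handle; a trivial kernel like this
# should report 0 spills, while a spilling attention config would not.
handle = copy_kernel[(triton.cdiv(x.numel(), 1024),)](x, y, x.numel(), BLOCK=1024)
print("registers/thread:", handle.n_regs, "spill slots:", handle.n_spills)
```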
Changing BLOCK_M from 128 to 64 leads to worse performance:
BLOCK_M 128:
BLOCK_M 64:
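For what it's worth, BLOCK_M is only one of the knobs: the tutorial picks tiles with @triton.autotune, so a broader sweep over BLOCK_N, num_warps, and num_stages might behave differently than just halving BLOCK_M. Here is a sketch of a wider config list (the variable name `configs` mirrors the tutorial; the exact defaults vary by Triton version).

```python
import triton

# Drop-in replacement for the `configs` list passed to @triton.autotune on the
# tutorial's forward kernel; the autotuner should skip combinations that run
# out of shared memory and keep the fastest of the rest.
configs = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_stages=ns, num_warps=nw)
    for bm in (128, 64, 32)
    for bn in (128, 64, 32)
    for ns in (2, 3, 4)
    for nw in (4, 8)
]
```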
yeah then idk
Thanks for your reply. I'll dig deeper into this.
Any progress on it? I'm also curious about it.
I haven't done a kernel analysis with ncu yet, but the performance gap is reduced when using Triton 3.2.