Triton version is faster in both forward and backward when head dim is 64 but slower in both when head dim is 128

Open SonicZun opened this issue 7 months ago • 10 comments

Thanks for your amazing work!

With the current version of the Triton implementation, https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py, the Triton kernel is faster in both forward and backward than the CUTLASS implementation of FlashAttention-2 using the tutorial's default config (head dim 64). However, when the head dim is changed to 128, the Triton version is slower in both forward and backward.

Do you know what might be causing this? And in the CUTLASS implementation, are there any specific optimizations that prevent significant performance degradation when the head dimension is 128?
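A minimal sketch of how the FA2 side of such a comparison can be measured (the sequence length and the FLOP formula below are assumptions modeled on the tutorial's benchmark, not taken from it):

```python
import torch
from triton.testing import do_bench
from flash_attn import flash_attn_func

# Shapes follow the tutorial's defaults; SEQLEN = 4096 is an assumption.
BATCH, N_HEADS, SEQLEN = 4, 32, 4096

def fa2_tflops(head_dim, causal=True):
    q, k, v = (torch.randn(BATCH, SEQLEN, N_HEADS, head_dim,
                           device="cuda", dtype=torch.float16)
               for _ in range(3))
    ms = do_bench(lambda: flash_attn_func(q, k, v, causal=causal))
    # Two matmuls (QK^T and PV) at 2*S*S*D FLOPs each, forward only.
    flops = 4 * BATCH * N_HEADS * SEQLEN * SEQLEN * head_dim
    if causal:
        flops //= 2
    return flops / (ms * 1e-3) / 1e12

for hd in (64, 128):
    print(f"head_dim={hd}: {fa2_tflops(hd):.1f} TFLOPS (fwd, causal)")
```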

SonicZun avatar Mar 27 '25 02:03 SonicZun

On A100 or H100? If H100, the Triton version uses the new H100 instructions but FA2 doesn't. You should try FA3 if you're on H100.

tridao avatar Mar 27 '25 03:03 tridao

Thanks for your reply! The testing environment is:

  • GPU: A100
  • flash-attn: 2.6.3
  • triton: 3.1.0

SonicZun avatar Mar 27 '25 04:03 SonicZun

What speed (TFLOPS) do you get?

tridao avatar Mar 27 '25 04:03 tridao

[Image: benchmark results]

This is the result of https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py with:

  • BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
  • BATCH, N_HEADS, HEAD_DIM = 4, 32, 128

SonicZun avatar Mar 27 '25 05:03 SonicZun

You can try FA3 too, which runs on A100 now. Btw, the Triton bwd does not support causal=False; when you call it with causal=False it still runs with causal=True. You can check the output.
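One quick way to check is to compare gradients against a PyTorch SDPA reference; a rough sketch (the `triton_attention` call is a placeholder for the tutorial's `attention` function and is left commented out here):

```python
import torch
import torch.nn.functional as F

B, H, S, D = 1, 4, 256, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))
dout = torch.randn_like(q)

def dq_of(attn_fn, causal):
    qi, ki, vi = (t.detach().clone().requires_grad_() for t in (q, k, v))
    attn_fn(qi, ki, vi, causal).backward(dout)
    return qi.grad

ref = lambda q_, k_, v_, c: F.scaled_dot_product_attention(q_, k_, v_, is_causal=c)
dq_noncausal = dq_of(ref, causal=False)
dq_causal = dq_of(ref, causal=True)

# Placeholder: the tutorial's attention(q, k, v, causal, sm_scale) autograd function.
# dq_triton = dq_of(lambda q_, k_, v_, c: triton_attention(q_, k_, v_, c, D ** -0.5),
#                   causal=False)
# If dq_triton matches dq_causal rather than dq_noncausal, the Triton backward
# is ignoring causal=False.
```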

tridao avatar Mar 27 '25 14:03 tridao

You can try FA3 too, which runs on A100 now. Btw, the Triton bwd does not support causal=False; when you call it with causal=False it still runs with causal=True. You can check the output.

Thanks for pointing out my mistake!

Do you know what causes Triton's slowdown at head_dim=128? In contrast, the official FA2 performs well. Is this due to targeted optimizations?

SonicZun avatar Mar 27 '25 15:03 SonicZun

Most likely register spilling if I have to guess. You can try smaller block sizes to see if that helps.
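A rough way to look for spills without a full profile is to read the stats Triton attaches to a compiled kernel handle; a sketch, assuming the `n_regs`/`n_spills` attributes exposed by Triton 3.x (ncu gives the definitive numbers):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Trivial stand-in kernel; the same inspection applies to the tutorial's
    # attention kernels after they are launched.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
handle = _scale_kernel[grid](x, y, x.numel(), BLOCK=1024)
# n_regs/n_spills are assumed from Triton 3.x internals and may differ by version.
print("regs/thread:", handle.n_regs, "spill slots:", handle.n_spills)
```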

tridao avatar Mar 27 '25 15:03 tridao

Most likely register spilling if I have to guess. You can try smaller block sizes to see if that helps.

Changing BLOCK_M from 128 to 64 leads to worse performance:

BLOCK_M 128: [Image: benchmark results]

BLOCK_M 64: [Image: benchmark results]

SonicZun avatar Mar 27 '25 15:03 SonicZun

yeah then idk

tridao avatar Mar 28 '25 03:03 tridao

yeah then idk

Thanks for your reply. I'll dig deeper into this.

SonicZun avatar Mar 28 '25 03:03 SonicZun

yeah then idk

Thanks for your reply. I'll dig deeper into this.

Any progress on it? I'm also curious about it.

daoxian avatar Apr 29 '25 06:04 daoxian

yeah then idk

Thanks for your reply. I'll dig deeper into this.

Any progress on it? I'm also curious about it.

I haven't performed a kernel analysis with ncu (Nsight Compute) yet, but the performance gap can be reduced using Triton 3.2.

SonicZun avatar Apr 29 '25 06:04 SonicZun