flash-attention
FlashAttention forward support for Turing
Hi, I tried my hand at implementing the flash attention forward pass for the Turing architecture. This is the repo:
https://github.com/ssiu/flash-attention-turing
As this is still an early implementation, it only supports:
- head_dim = 128
- vanilla attention, i.e. no masking
- seq_len must be divisible by 128
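For reference, here is a minimal sketch of inputs that satisfy these constraints, assuming a (batch, heads, seq_len, head_dim) layout. The `flash_attn_forward` call is only a hypothetical stand-in for whatever entry point the repo actually exposes, and the sequence length is just an example multiple of 128:

```python
import torch

# Shapes from the benchmark below; seq_len = 4096 is only an example
# multiple of 128 (a requirement of this early implementation).
batch_size, num_heads, seq_len, head_dim = 4, 32, 4096, 128
assert head_dim == 128 and seq_len % 128 == 0

# Half-precision inputs on a Turing GPU such as a T4.
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Hypothetical call -- check the repo for the actual Python binding (if any).
# out = flash_attn_forward(q, k, v)   # vanilla attention only, no mask
```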
For batch_size = 4, num_heads = 32, head_dim = 128, our implementation is currently around 2x faster than PyTorch's F.scaled_dot_product_attention, which falls back to the Memory-Efficient Attention backend. This was tested on a T4.
For long sequences it reaches around 63% of the T4's compute throughput.
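For anyone who wants to reproduce the baseline side of that comparison, here is a rough sketch that times F.scaled_dot_product_attention while restricting it to the Memory-Efficient backend (the backend the post says SDPA selects on a T4). It uses the older `torch.backends.cuda.sdp_kernel` context manager; newer PyTorch releases expose `torch.nn.attention.sdpa_kernel` instead, and the shapes and iteration counts are just illustrative:

```python
import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, head_dim = 4, 32, 4096, 128  # example shapes
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the Memory-Efficient Attention backend.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False,
                                    enable_mem_efficient=True):
    for _ in range(10):                       # warm-up
        F.scaled_dot_product_attention(q, k, v)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        F.scaled_dot_product_attention(q, k, v)
    end.record()
    torch.cuda.synchronize()
    print(f"mem-efficient SDPA: {start.elapsed_time(end) / 100:.3f} ms/iter")
```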
Thanks!
Wow this is great work!
Thanks for this, I have been trying to find a way to get flash attention working on a T4 GPU.
You've helped me so much.
This is great work! I have been looking for this. Is there any new progress?
Hi @lwllvyb, yes I will be continuing to work on this.
Great, kudos on your work!
@ssiu This is awesome! I can’t wait for this to be official. You’re doing incredible work - keep going, you’re amazing! 🙌💪