FlagAttention
A collection of memory efficient attention operators implemented in the Triton language.
Supports grouped query attention (GQA) for flash_attn (related kernels: fwd, bwd, split_kv, total_attention). The GQA paper: > Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. “GQA:...
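The GQA idea above can be sketched in plain PyTorch: each KV head is shared by a group of query heads, which a reference implementation can emulate by repeating the KV heads before a standard attention call. Shapes below are illustrative assumptions, not the kernel's actual configuration.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 8 query heads share 2 KV heads (4 queries per group).
batch, seqlen, head_dim = 2, 16, 64
num_q_heads, num_kv_heads = 8, 2
group = num_q_heads // num_kv_heads

q = torch.randn(batch, num_q_heads, seqlen, head_dim)
k = torch.randn(batch, num_kv_heads, seqlen, head_dim)
v = torch.randn(batch, num_kv_heads, seqlen, head_dim)

# Expand KV heads so every query head in a group attends to its shared KV head.
k_rep = k.repeat_interleave(group, dim=1)
v_rep = v.repeat_interleave(group, dim=1)

out = F.scaled_dot_product_attention(q, k_rep, v_rep)
assert out.shape == (batch, num_q_heads, seqlen, head_dim)
```

A fused GQA kernel avoids materializing the repeated KV tensors; this sketch only shows the reference semantics the kernels should match.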
Adds bias to attention. Many tests fail for me (that's why I'm opening a draft PR), especially the BTHD and longer-sequence ones (my GPU has 12 GB), but manual PyTorch tests...
The PyTorch base implementation of [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention) provides dropout as an argument. Fusing it into the Triton kernel would replicate that functionality, since dropout is applied to the attention scores, not...
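For reference, the semantics being described can be sketched in eager PyTorch: dropout acts on the post-softmax attention probabilities, before they are multiplied with the values. The function name and shapes here are illustrative, not part of the repo's API.

```python
import math
import torch
import torch.nn.functional as F

def attention_with_dropout(q, k, v, dropout_p=0.1, training=True):
    # Dropout is applied to the attention probabilities (after softmax),
    # not to the final output, matching scaled_dot_product_attention.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    probs = torch.softmax(scores, dim=-1)
    probs = F.dropout(probs, p=dropout_p, training=training)
    return probs @ v
```

With `dropout_p=0.0` this reduces to plain attention, which is why a fused kernel only needs to mask attention probabilities with the dropout pattern to reproduce the reference behavior.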
Thanks for your amazing work! In the latest version of Triton's official tutorial, when directly running the Flash Attention 2 Python script (head dim is 64), Triton's implementation is faster...
https://github.com/FlagOpen/FlagAttention/blob/548382d704db9310e7613a542bb81efc96ef3890/src/flag_attn/flash.py#L215 If this line is changed to 64, 64, 1, 4, it produces errors and a failed test:
```
root@321baeafd609:/workspace/FlagAttention# pytest /workspace/FlagAttention/tests/flag_attn/test_flash_attention.py -x
============================================= test session starts ==============================================
platform linux -- Python...
```
#### 1. Switch off NVIDIA MMA acceleration
In the `triton` project, edit `third_party/nvidia/backend/compiler.py` and delete `passes.ttgpuir.add_accelerate_matmul(pm)` to switch off MMA and use FMA instead.
#### 2. Use fp16 as the dot accumulate data type in...