
[AMD] Triton Backend for ROCm #1

Open micmelesse opened this issue 1 year ago • 1 comments

Hi, this is a PR to add a Triton backend to Flash Attention on ROCm. We hope it will be the first in a series of PRs to that end. Triton has supported ROCm for a while now, and a Flash Attention Triton backend will allow us to support Flash Attention on both our MI and Navi machines.

In this PR, we enable major parts of fwd, varlen_fwd and fwd_kvcache. However, some features are still missing, such as dropout, sliding window, rotary embedding and paged attention. There are also a few miscellaneous bugs. These will all be addressed in subsequent PRs. The next PR we plan to file will add support for bwd and varlen_bwd. If we should reprioritize, please let us know.

We have tested this PR here on an MI200 machine with this Triton commit. When testing the Triton backend for ROCm, we skip the backward pass and configs with unsupported features. We also randomly selected about 20% of the head sizes (d) to keep test times reasonable, and we can test more if needed. The latest results are === 65 failed, 30386 passed, 478321 skipped, 1 warning in 4572.26s (1:16:12) ===. There is clearly more work to be done, but we hope this makes a good start. We have included instructions for running the Triton backend in the README; the main point is to use export FLASH_ATTENTION_USE_TRITON_ROCM="TRUE" with Triton installed.
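For anyone who prefers to opt in from Python rather than the shell, the following is a minimal sketch. It assumes the variable has to be set before flash_attn is imported, which this thread does not confirm. The helper triton_rocm_requested is a hypothetical name used only for illustration and is not part of the package.

```python
import importlib.util
import os

# Assumption: the backend choice is read when flash_attn is imported,
# so the variable is set first. The README is the authoritative source.
os.environ["FLASH_ATTENTION_USE_TRITON_ROCM"] = "TRUE"


def triton_rocm_requested() -> bool:
    """Hypothetical helper: report whether the opt-in variable is set to TRUE."""
    value = os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE")
    return value.upper() == "TRUE"


# Check that the packages are installed without actually importing them.
have_flash_attn = importlib.util.find_spec("flash_attn") is not None
have_triton = importlib.util.find_spec("triton") is not None

print("Triton ROCm backend requested:", triton_rocm_requested())
print("flash_attn installed:", have_flash_attn)
print("triton installed:", have_triton)
```

Running this before a test session is a quick way to confirm that the opt-in is visible to the process and that both packages are present.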

Please let us know what we can do on our end to help with this process.

Finally, this PR includes work from multiple people besides myself. Special thanks to @vgokhale, @scxiao and @jlgreathouse.

micmelesse avatar Sep 04 '24 12:09 micmelesse

The Gods are Gracious

unclemusclez avatar Sep 15 '24 21:09 unclemusclez

will this work with CDNA 1?

unclemusclez avatar Oct 30 '24 02:10 unclemusclez

will this work with CDNA 1?

The kernels work on any architecture supported by the Triton compiler. Right now, the Triton compiler does not officially support the MI100 series, but most cases should work. We are focused on MI300 and MI200 on the CDNA side.

micmelesse avatar Oct 30 '24 15:10 micmelesse

Hi @tridao

Hope you are doing well. I wanted to check if you have any feedback or suggestions regarding this PR. I've refreshed it to include support for the backward pass and have refactored it to be more modular and easier to review.

We would be happy to add more features or work on performance improvements if needed. If you have any fundamental reservations about adding a Triton backend, please let us know, and we will do everything we can to address them.

Thank you for your time.

micmelesse avatar Oct 30 '24 16:10 micmelesse

Is there anything holding this back?

dtrifiro avatar Nov 22 '24 09:11 dtrifiro

Is there anything holding this back?

We are just waiting for feedback

micmelesse avatar Nov 22 '24 14:11 micmelesse

These features are supported in Fwd for now. We will add them to backward soon.
1. Multi and grouped query attention
2. ALiBi

Does this mean the AMD version cannot support alibi_bias until the backward pass is added? When using alibi_bias, I always get numerical instability (NaN values) during training.
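As a reference point when chasing NaNs like this, here is a minimal pure-Python sketch of the ALiBi bias, which adds -slope * |i + seqlen_k - seqlen_q - j| to the score of query i and key j. The slope schedule is the geometric one from the ALiBi paper and this simple form assumes a power-of-two head count. The helper names alibi_slopes, alibi_bias and all_finite are hypothetical, and this is not the kernel code from the PR.

```python
import math


def alibi_slopes(n_heads: int) -> list[float]:
    """Geometric ALiBi slopes; this simple form assumes n_heads is a power of two."""
    assert n_heads > 0 and n_heads & (n_heads - 1) == 0
    return [2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)]


def alibi_bias(slope: float, seqlen_q: int, seqlen_k: int) -> list[list[float]]:
    """Bias added to the score of query i and key j for a single head."""
    offset = seqlen_k - seqlen_q
    return [
        [-slope * abs(i + offset - j) for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]


def all_finite(rows: list[list[float]]) -> bool:
    """True if no entry is NaN or infinite."""
    return all(math.isfinite(x) for row in rows for x in row)


slopes = alibi_slopes(8)
bias = alibi_bias(slopes[0], seqlen_q=4, seqlen_k=4)
print("bias is finite:", all_finite(bias))
```

The bias itself is always finite for finite slopes, so if a check like all_finite fails on the attention output or on the gradients, the problem lies elsewhere, for example in a pass that does not yet handle the bias.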

TerminatorJ avatar Apr 13 '25 22:04 TerminatorJ

We should be putting up a PR in a week or two with backward support and ALiBi. I will ping you when it is up.

micmelesse avatar Apr 14 '25 14:04 micmelesse

The second PR is up at https://github.com/Dao-AILab/flash-attention/pull/1610 and the discussion is at https://github.com/Dao-AILab/flash-attention/pull/1561

micmelesse avatar Apr 22 '25 23:04 micmelesse