
Add Flash Attention 2 support

rajveer43 opened this issue 2 years ago • 5 comments

Is your feature request related to a problem? Please describe.
Flash Attention 2 is a library that provides attention operation kernels for faster and more memory-efficient inference and training.

Describe the solution you'd like
Add Flash Attention 2 support (https://github.com/Dao-AILab/flash-attention) to keras-nlp.
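For reference, a minimal sketch of what the kernel provides on the PyTorch side, assuming the flash-attn package is installed and a CUDA GPU is available (fp16/bf16 tensors only):

```python
# Sketch only: requires `pip install flash-attn` and a CUDA device.
# flash_attn_func expects fp16/bf16 tensors of shape
# (batch, seq_len, num_heads, head_dim).
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

# Fused attention: never materializes the full (seq_len x seq_len) score matrix.
out = flash_attn_func(q, k, v, causal=True)  # shape (2, 1024, 8, 64)
```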

rajveer43 avatar Oct 05 '23 06:10 rajveer43

Thanks! We definitely want this, and are actively looking at how to best support this in a cross-platform way for JAX, Torch, and TF.

Ideally this is something we can handle at the compiler level of our stack (e.g. with XLA), as we would love to avoid writing custom kernels for specific sets of hardware in this repo.
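For illustration, a rough sketch of what a backend-level path could look like on the JAX side, assuming a recent JAX release that exposes jax.nn.dot_product_attention (the implementation argument selects a fused cuDNN flash-attention kernel where the hardware supports it):

```python
# Sketch only: implementation="cudnn" asks XLA to use the fused
# flash-attention kernel on supported GPUs; "xla" falls back to the
# plain composed ops, with no custom kernels in this repo.
import jax
import jax.numpy as jnp

B, T, N, H = 2, 1024, 8, 64  # batch, seq_len, num_heads, head_dim
q = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)
k = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)
v = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)

out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="cudnn")
```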

Will use this issue to post any updates as our plan develops.

mattdangerw avatar Oct 05 '23 21:10 mattdangerw

Thanks!

rajveer43 avatar Oct 06 '23 05:10 rajveer43

I think flash-attention was added to XLA recently. Is there any way to check whether the current attention implementation in this repo can be lowered to the flash-attention rewriter?
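One way to check (a sketch, not an official workflow) is to JIT a plain attention function with JAX and inspect the compiled HLO for a fused attention custom call; the exact custom-call name varies across XLA versions, so the "fmha"/"cudnn" substrings below are an assumption:

```python
# Sketch only: lower and compile a naive attention function, then
# search the optimized HLO text for a fused attention custom call.
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # Naive attention, as the rewriter would see it before fusion.
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

B, T, N, H = 2, 1024, 8, 64
x = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)

hlo = jax.jit(attention).lower(x, x, x).compile().as_text()
print(any(s in hlo for s in ("fmha", "cudnn")))  # True if a fused kernel was selected
```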

dathudeptrai avatar Mar 02 '24 10:03 dathudeptrai

I am interested in contributing to it.

shanky-kapoor avatar Mar 04 '24 19:03 shanky-kapoor

FYI, there is now an unofficial JAX binding for FlashAttention.

AndreasMadsen avatar May 18 '24 00:05 AndreasMadsen