keras-nlp
Add Flash Attention 2 support
Is your feature request related to a problem? Please describe.
Flash Attention 2 is a library that provides attention kernels for faster and more memory-efficient inference and training: https://github.com/Dao-AILab/flash-attention
Describe the solution you'd like
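For reference, a minimal sketch of how the library's fused kernel is typically invoked, assuming flash-attn 2.x is installed and a supported CUDA GPU is available; shapes and dtypes follow the project's documented convention of (batch, seq_len, num_heads, head_dim) in float16/bfloat16:

```python
# Minimal sketch, assuming flash-attn 2.x and a CUDA GPU (not tied to keras-nlp).
import torch
from flash_attn import flash_attn_func

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Fused attention kernel; causal=True applies a lower-triangular mask.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # (2, 1024, 8, 64)
```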
Thanks! We definitely want this, and are actively looking at how best to support it in a cross-platform way for JAX, PyTorch, and TensorFlow.
Ideally this is something we can handle at the compiler level of our stack (e.g. with XLA), as we would love to avoid writing custom kernels for specific sets of hardware in this repo.
We'll use this issue to post any updates as our plan develops.
Thanks!
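As one illustration of the compiler-level path described above: recent JAX releases expose jax.nn.dot_product_attention, which can dispatch to a fused (FlashAttention-style) cuDNN kernel without any custom kernels in user code. A minimal sketch, assuming a JAX version that includes this API; the implementation flag values are as documented by JAX, not by keras-nlp:

```python
# Minimal sketch, assuming a recent JAX release with jax.nn.dot_product_attention.
import jax
import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (batch, seq_len, num_heads, head_dim)
q = jax.random.normal(q_key, shape, dtype=jnp.bfloat16)
k = jax.random.normal(k_key, shape, dtype=jnp.bfloat16)
v = jax.random.normal(v_key, shape, dtype=jnp.bfloat16)

# implementation=None lets the backend pick; implementation="cudnn" requests the
# fused (FlashAttention-style) kernel where supported.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation=None)
print(out.shape)  # (2, 1024, 8, 64)
```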
I think flash-attention was added to XLA recently. Is there any way to check whether this repo's current attention implementation can be lowered by the flash-attention rewriter?
I am interested in contributing to it.
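One way to check is to lower and compile a plain attention function with jax.jit and inspect the optimized HLO for a fused-attention custom call. A minimal sketch; the custom-call substring searched for ("fmha") is an assumption and varies across XLA versions and backends:

```python
# Minimal sketch: inspect compiled HLO to see if XLA rewrote attention into a fused kernel.
import jax
import jax.numpy as jnp

def naive_attention(q, k, v):
    # Reference attention written with plain jnp ops, shape (batch, heads, seq, dim).
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1]).astype(q.dtype)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)

x = jnp.ones((2, 8, 1024, 64), dtype=jnp.bfloat16)
hlo = jax.jit(naive_attention).lower(x, x, x).compile().as_text()
print("fused attention custom call found:", "fmha" in hlo.lower())
```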
FYI, there is now an unofficial JAX binding for FlashAttention.