flash-attention-jax
can I work on making a flax attention function out of this repository?
Hi lucidrains!
I want to use flash attention in one of my projects: a transformer that works on sequences as long as 2400 with a batch size of 1000. The original flash attention implementation does not fit in memory for me, so I went looking for alternatives and found your implementation.
However, I found that I cannot simply pass your attention implementation as the `attention_fn` of `flax.linen.MultiHeadDotProductAttention`, because there the `attention_fn` must operate on multi-headed inputs and accept `mask`, `dropout_rate`, and so on.
I was wondering if I could take your flash attention building block and add the required capabilities on top of it, along the lines of the sketch below. I am not familiar with the flash attention internals, but I am familiar with JAX and Flax, so: is this doable without understanding the underlying flash attention algorithm? If you think it is possible, I can work on it and open a pull request.
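For concreteness, here is a minimal sketch of the kind of wrapper I have in mind. It is untested and rests on two assumptions: that `flash_attention(q, k, v, key_mask)` from this repo takes `q`/`k`/`v` of shape `(batch, heads, seq, dim)` and a `key_mask` of shape `(batch, seq)` and returns the attention output, as the README suggests, and that Flax calls `attention_fn` with `query`/`key`/`value` of shape `(batch, seq, num_heads, head_dim)`, i.e. the contract of `flax.linen.dot_product_attention`. Dropout and attention bias are simply not handled here:

```python
import jax.numpy as jnp
from flax import linen as nn
from flash_attention_jax import flash_attention  # from this repo


def flash_attention_fn(query, key, value, bias=None, mask=None,
                       broadcast_dropout=True, dropout_rng=None,
                       dropout_rate=0.0, deterministic=False,
                       dtype=None, precision=None):
    """Adapter from Flax's attention_fn calling convention to flash_attention."""
    if bias is not None or dropout_rate > 0.0:
        raise NotImplementedError('bias / dropout are not handled in this sketch')

    # Flax convention (batch, seq, heads, dim) -> this repo's (batch, heads, seq, dim).
    q = jnp.swapaxes(query, -3, -2)
    k = jnp.swapaxes(key, -3, -2)
    v = jnp.swapaxes(value, -3, -2)

    if mask is None:
        # Assuming the kernel expects a key padding mask of shape (batch, kv_len);
        # with no mask given, attend to every key.
        key_mask = jnp.ones((k.shape[0], k.shape[2]), dtype=bool)
    else:
        # Flax masks broadcast to (batch, heads, q_len, kv_len); collapsing the
        # heads and query axes recovers a per-key padding mask. This is only
        # valid for padding masks, not causal or other per-query masks.
        key_mask = jnp.any(mask, axis=(-3, -2))

    out = flash_attention(q, k, v, key_mask)  # (batch, heads, q_len, dim)
    return jnp.swapaxes(out, -3, -2)          # back to (batch, q_len, heads, dim)


# Hypothetical usage (num_heads is illustrative):
attn = nn.MultiHeadDotProductAttention(num_heads=8, attention_fn=flash_attention_fn)
```

The open question for me is whether dropout and non-padding masks can be supported at this wrapper level or whether they have to go into the kernel itself (if I read the README right, causal_flash_attention would at least cover the causal case).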
@MiladInk This might be of interest: https://github.com/google/flax/issues/2858.