
Can I work on making a Flax attention function out of this repository?

Open • MiladInk opened this issue 3 years ago • 1 comment

Hi lucidrains!

I want to use flash attention in one of my projects. I need a transformer model that works on sequences as long as 2400 with a batch size of 1000. Standard attention does not fit in memory for me at that size, so I looked for a flash attention implementation and found yours.

However, I found that I cannot just pass your attention implementation to `flax.linen.MultiHeadDotProductAttention`, because its `attention_fn` argument has to be multi-headed and accept `mask`, `dropout_rate`, etc.

Could I use your flash attention building block and add the required capabilities to it? I am not familiar with the flash attention implementation, but I am familiar with JAX and Flax, so my main question is whether this is doable without understanding the underlying flash attention. If you think it is possible, I can work on it and then create a pull request.
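To make the request concrete, here is a minimal sketch of the kind of adapter this would involve. It assumes several things that are not confirmed anywhere in this thread: that Flax calls `attention_fn(query, key, value, bias=None, mask=None, ...)` with arrays laid out as `[batch, length, heads, depth]`, that the core kernel takes `[batch, heads, length, depth]` inputs plus a boolean `[batch, kv_length]` key mask, and that the mask is padding-style (the same for every head and query position). `reference_attention` is a plain-softmax stand-in used only so the example runs on its own; it is not this repository's kernel, and `make_flax_attention_fn` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, key_mask):
    """Stand-in for a flash-attention kernel (assumed signature, not the repo's).

    q, k, v: [batch, heads, length, depth]; key_mask: [batch, kv_length] bool.
    """
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # Masked-out keys get a large negative logit so their weight becomes zero.
    logits = jnp.where(key_mask[:, None, None, :], logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


def make_flax_attention_fn(core_attention):
    """Wrap a [batch, heads, length, depth] kernel in a Flax-style attention_fn."""

    def attention_fn(query, key, value, bias=None, mask=None,
                     dropout_rate=0.0, deterministic=True, **unused_kwargs):
        # query/key/value arrive as [batch, length, heads, depth] (assumed layout).
        if bias is not None:
            raise NotImplementedError("bias is not handled in this sketch")
        if dropout_rate > 0.0 and not deterministic:
            raise NotImplementedError("attention dropout is not handled in this sketch")

        batch, kv_length = key.shape[0], key.shape[1]
        if mask is None:
            key_mask = jnp.ones((batch, kv_length), dtype=bool)
        else:
            # Assumption: mask is [batch, heads, q_length, kv_length] (or
            # broadcastable to it) and only encodes key padding, so the first
            # head and first query row describe it completely.
            key_mask = jnp.broadcast_to(
                mask[:, 0, 0, :].astype(bool), (batch, kv_length))

        # Move the heads axis in front of the length axis for the kernel.
        q = jnp.transpose(query, (0, 2, 1, 3))
        k = jnp.transpose(key, (0, 2, 1, 3))
        v = jnp.transpose(value, (0, 2, 1, 3))
        out = core_attention(q, k, v, key_mask)
        # Back to [batch, q_length, heads, depth].
        return jnp.transpose(out, (0, 2, 1, 3))

    return attention_fn
```

If those assumptions hold, the wrapper is only layout glue: swapping `reference_attention` for the real kernel would not require touching its internals. Causal or per-query masks, `bias`, and attention dropout would need support inside the kernel itself, which is why the sketch raises instead of silently ignoring them.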

MiladInk • Sep 19 '22 03:09

@MiladInk This might be of interest: https://github.com/google/flax/issues/2858.

carlosgmartin • Mar 17 '24 02:03