x-transformers
Implement/Integrate Flash-Attention
Hi Lucidrains,
I have stumbled across this paper that might be of interest to you.
Given how memory- and compute-hungry transformers are, I think this would be highly useful for the repo.
Here is the reference: https://arxiv.org/abs/2205.14135
Here is the repo: https://github.com/HazyResearch/flash-attention
Cheers!
Dan
Hi Petru-Daniel
Yes indeed, it looks like a breakthrough moment
I've been playing around with it here and here and it works as advertised
I'll probably wait for Tri to whittle away a few more issues before I start using it for all my new transformers projects
In particular, it still lacks the ability to take in attention biases, as well as support for causal masking when the query and key/value lengths differ
Hi @lucidrains any update?
@lucidrains
Given that PyTorch 2.0 released the new F.scaled_dot_product_attention, would it make sense to include it in this repo as an optional parameter?
One could add some asserts to tell the user that certain attention-bias and masking methods are not supported when flash attention is enabled.
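For illustration, a minimal sketch of what such an optional dispatch with asserts could look like (the `attend` function and `use_flash` flag are hypothetical names for this sketch, not x-transformers' actual API):

```python
import torch
import torch.nn.functional as F

def attend(q, k, v, attn_bias = None, causal = False, use_flash = False):
    # q, k, v: (batch, heads, seq_len, dim_head)
    if use_flash:
        # F.scaled_dot_product_attention (pytorch >= 2.0) fuses softmax(QK^T)V;
        # its flash kernel path does not take an additive attention bias
        assert attn_bias is None, 'attention bias (e.g. alibi) is not supported with flash attention enabled'
        return F.scaled_dot_product_attention(q, k, v, is_causal = causal)

    # vanilla attention fallback, materializes the full attention matrix
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if attn_bias is not None:
        sim = sim + attn_bias

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)
```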
Also, one more question: would it be possible to adapt rotary_xpos=True
so that it is also compatible with full self-attention (bidirectional, no causal masking), similarly to how the alibi bias was adapted to full self-attention?
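For context on the analogy being drawn: the alibi adaptation for full self-attention just makes the distance penalty symmetric so it no longer assumes a causal ordering. A purely illustrative sketch of such a symmetric bias (not the x-transformers implementation; the slope schedule follows the usual alibi geometric sequence):

```python
import torch

def symmetric_alibi_bias(seq_len, heads, device = None):
    # bidirectional alibi penalizes absolute distance |i - j| instead of (j - i),
    # so it works without a left-to-right causal order
    slopes = torch.tensor([2 ** (-8 * (h + 1) / heads) for h in range(heads)], device = device)
    pos = torch.arange(seq_len, device = device)
    dist = (pos[None, :] - pos[:, None]).abs()           # (seq_len, seq_len)
    return -dist[None, :, :] * slopes[:, None, None]     # (heads, seq_len, seq_len), added to the attention logits
```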
PS: Have you checked out the memory-efficient attention implementation in the xformers repo (it can be used with PyTorch 2.0)? They seem to offer the option to also pass an attention bias to the function (in contrast to PyTorch's F.scaled_dot_product_attention), see here: https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention
That would perhaps allow using extrapolation methods like alibi together with memory-efficient attention.
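For example, usage looks roughly like the following sketch (tensor shapes and bias layout are my assumptions from the linked docs; exact alignment/shape constraints on attn_bias should be checked against the xformers documentation):

```python
import torch
import xformers.ops as xops

# xformers expects (batch, seq_len, heads, dim_head) layout
q = torch.randn(1, 1024, 8, 64, device = 'cuda', dtype = torch.float16)
k = torch.randn(1, 1024, 8, 64, device = 'cuda', dtype = torch.float16)
v = torch.randn(1, 1024, 8, 64, device = 'cuda', dtype = torch.float16)

# additive attention bias, e.g. an alibi-style distance penalty per head
bias = torch.zeros(1, 8, 1024, 1024, device = 'cuda', dtype = torch.float16)

out = xops.memory_efficient_attention(q, k, v, attn_bias = bias)  # (1, 1024, 8, 64)
```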
it is done https://github.com/lucidrains/x-transformers#flash-attention
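For anyone landing here later, enabling it looks roughly like this, going by the linked README section (verify the exact flag name there):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True  # enables flash attention (requires pytorch >= 2.0)
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```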