Community contribution: Adding Flash Attention 2 support for more architectures
Feature request
Flash Attention 2 is a library that provides attention kernels for faster and more memory-efficient inference and training: https://github.com/Dao-AILab/flash-attention
Let's try to add Flash Attention 2 support for more architectures! Currently supported architectures are
- [x] Llama
- [x] Falcon
It would be great to add the support for more architectures such as
- [x] Bark
- [x] Bart
- [ ] BERT | @sorenmc
- [ ] CLIP https://github.com/huggingface/transformers/pull/27444/
- [x] DistilBERT
- [x] GPT-2
- [x] GPT-J
- [x] GPTBigCode (Starcoder) | @susnato
- [x] GPT-neo
- [x] GPT-neo-x | @younesbelkada #26463
- [x] OPT | @susnato #26414
- [x] Llava
- [x] VipLlava
- [x] mBART
- [x] Mistral
- [x] Mixtral
- [ ] MPT | @rajveer43
- [ ] T5
- [ ] Persimmon | @jeromeku
- [x] Phi
- [x] Whisper
- [x] Qwen2
... and many more
Adding this feature would require following the same protocol as in https://github.com/huggingface/transformers/pull/25598. First, create a new module inside the corresponding modeling file, named `xxxFlashAttention`, that inherits from `xxxAttention`, and override the `forward` method to use the public methods from `flash-attn`. Make sure to have access to a GPU that supports Flash Attention 2. A minimal sketch of the pattern is shown below.
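For illustration only, a rough sketch of such a class (all names are placeholders: `XxxAttention`, its `q_proj`/`k_proj`/`v_proj`/`o_proj` projections and head attributes are assumed from a hypothetical parent class, and real integrations also need to handle padding via `flash_attn_varlen_func`, KV caching and dtype casting):

```python
# Sketch only: XxxAttention and its projection/attribute names are placeholders,
# and padding, KV caching and dtype handling are deliberately omitted.
from flash_attn import flash_attn_func  # public kernel from the flash-attn package


class XxxFlashAttention2(XxxAttention):  # XxxAttention assumed defined earlier in the modeling file
    """Variant of XxxAttention that routes the attention computation through Flash Attention 2."""

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        bsz, q_len, _ = hidden_states.size()

        # Reuse the parent class's projections to build (batch, seq_len, num_heads, head_dim)
        # tensors, which is the layout flash_attn_func expects (in fp16/bf16).
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)

        # Causal attention without padding handling; real code must dispatch to
        # flash_attn_varlen_func when an attention mask with padding is present.
        attn_output = flash_attn_func(query_states, key_states, value_states, dropout_p=0.0, causal=True)

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        # Real forwards also return attention weights / cache entries as required by the model.
        return self.o_proj(attn_output)
```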
Given the slight challenge of the issue, we are labelling it as a good second issue!
If you are interested in taking up the challenge, comment below with the architecture name you want to integrate and open a PR!
Once you open a PR, feel free to ping @LysandreJik @ArthurZucker @amyeroberts @younesbelkada @fxmarty @SunMarc @pacman100 for a review
Motivation
Making LLMs more memory efficient and faster!
Your contribution
Reviewing PRs and possibly adding the support for more models
Hi @younesbelkada - I want to work on adding Flash Attention 2 support for GPTBigCode (Starcoder). Can I take this task? Can you please assign this task to me?
Will definitely take a look next week. Great to see it merged now 💪
I would like to work on MPT @younesbelkada
I would like to work on OPT.
Is it possible to add FlashAttention2 to GPT2 models?
@sahilbhosale63 @flozi00 @rajveer43 @susnato thanks very much for your interest! Indeed it would be great if you could help us! Before assigning you to this issue, can you confirm you have access to a GPU that does support Flash Attention 2: https://github.com/Dao-AILab/flash-attention#installation-and-features in order to be able to run the tests? @ZeusFSX, yes I think that it is possible, I'll update the list accordingly.
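(For reference, a rough way to check this is to look at the GPU's compute capability; Flash Attention 2 targets NVIDIA Ampere or newer cards, i.e. compute capability 8.0+, with the authoritative list in the flash-attn README linked above. A minimal sketch:)

```python
# Rough compatibility check: FA-2 targets compute capability 8.0+ (Ampere or newer);
# see the flash-attn README for the authoritative hardware list.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name()}: sm_{major}{minor} -> FA2 likely supported: {major >= 8}")
else:
    print("No CUDA device visible.")
```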
@younesbelkada Yes I have
OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or if you have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. For running the flash attention tests you can just run (once the PR is ready):
RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/
@younesbelkada yes I have.
Thanks @susnato, perfect then, let me know whenever you start the PR and if you have any questions! Check out my instructions above for more details.
@younesbelkada Unfortunately, my GPU is not supported.
> OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or if you have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. For running the flash attention tests you can just run (once the PR is ready):
> RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/
Sure I will work on it!
@younesbelkada Would like to work on Persimmon. I have access to A4000, A5000, and A6000, which I believe should be compatible with FA2.
Perfect, sounds great, thanks for your help, I will assign you to Persimmon!
Since @sahilbhosale63 is not working on GPTBigCode (Starcoder) (as he said here), can I take that @younesbelkada?
Yes no problem, thanks very much for proposing your help on this! As a starting point you can have a look at @pacman100's implementation here: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/personal_copilot/training/starcoder_flash_attn_monkey_patch.py
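(The general shape of such a monkey patch is roughly the following; this is an illustrative sketch only, with the attention body elided, not the content of the linked file:)

```python
# Illustrative sketch of the monkey-patch approach: replace the stock attention forward
# with one that calls the flash-attn kernels. The body is elided here; see the linked
# starcoder_flash_attn_monkey_patch.py for a complete implementation.
from transformers.models.gpt_bigcode import modeling_gpt_bigcode


def _flash_attention_forward(self, hidden_states, *args, **kwargs):
    # ... build q/k/v from the module's projections and call flash_attn_func here ...
    raise NotImplementedError("sketch only")


# Every GPTBigCodeAttention module now uses the patched forward.
modeling_gpt_bigcode.GPTBigCodeAttention.forward = _flash_attention_forward
```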
@younesbelkada I would like to implement it for BERT, if it hasn't already been done? A lot of the models topping MTEB still rely on this architecture! I have tested that I can run Flash Attention 2 on my NVIDIA GeForce RTX 3060 Ti.
Awesome, thanks a lot for your help, ok I will assign you to BERT then!
Hi everyone, I would like to help implement this with GPT2 if you want.
@younesbelkada I have a working version for Persimmon that passes the `flash_attn_v2` tests except for `generate_padding_right`, as the original `PersimmonFlashAttention` does not have `padding_mask` as a kwarg input (as opposed to the Llama and Falcon flash implementations). Is this something that needs to be changed in both Persimmon Flash v1 and v2?
Also, any plans on incorporating additional optimizations? E.g., the Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally. Happy to investigate more!
Also, I would like to help with Mistral-7b (just released). They use xformers memory-efficient attention in their released implementation but also mention Tri Dao's FA in the blog post.
Hi @DougTrajano Awesome! Can you confirm you have access to hardware that is supported by FA-2?
@jeromeku awesome thanks! Can you move forward for Persimmon by opening a PR so that I can have a look?
> Also, any plans on incorporating additional optimizations? E.g., the Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally. Happy to investigate more!
If that is something that can nicely fit into the API without any breaking behaviour, that would be great!
> Also, I would like to help with Mistral-7b (just released). They use xformers memory-efficient attention in their released implementation but also mention Tri Dao's FA in the blog post.
I think Mistral's attention has been released in the latest version of FA-2 --> Would you be happy to open a PoC PR so that I can play with it and see what we can do?
Again thanks a lot!
Hi @jeromeku I had to check internally for Mistral, given the very recent release and the urgency, we'll take this over (https://github.com/huggingface/transformers/pull/26464); if you have started a PR, I'm very happy to start from it or to add you as a co-author to the PR ! We might also refactor things a bit to support Local attention introduced by Mistral, so that needs further investigation, I'll keep you posted
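(For context: flash-attn >= 2.3 added a window_size argument to flash_attn_func that can express this kind of sliding-window local attention; a rough sketch, with tensor shapes chosen only for illustration:)

```python
# Sketch only: requires flash-attn >= 2.3, which added the window_size argument.
# With causal=True and window_size=(w, 0), each token attends to itself and the
# previous w tokens, i.e. causal attention restricted to a sliding window of w + 1.
import torch
from flash_attn import flash_attn_func

batch, seq_len, num_heads, head_dim = 1, 8192, 32, 128
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True, window_size=(4095, 0))  # 4096-token sliding window
print(out.shape)  # torch.Size([1, 8192, 32, 128])
```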
@younesbelkada what is the expected deadline to complete MPT? I have other issues to tackle, so I can plan accordingly.
Hi @younesbelkada, I am taking this up for GPT-neo.
Awesome @susnato ! Thanks ! @rajveer43 thanks for taking up MPT, will check it out!
> Hi @DougTrajano Awesome! Can you confirm you have access to hardware that is supported by FA-2?
Yes, I'll work on AWS SageMaker.
Would love to take on GPT2!
Thanks for confirming @DougTrajano! @marcasty thanks a lot for your interest; @DougTrajano has taken up GPT2, would you be happy to take another model? 🙏 Can you also confirm you have access to hardware that supports FA-2?
Hi @younesbelkada, I am taking this up for DistilBERT.
@younesbelkada what about T5? I have access to compatible hardware