
Community contribution: Adding Flash Attention 2 support for more architectures

Open younesbelkada opened this issue 1 year ago • 103 comments

Feature request

Flash Attention 2 is a library that provides attention operation kernels for faster and more memory-efficient inference and training: https://github.com/Dao-AILab/flash-attention
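For models that already have support, FA2 can be switched on when loading the model. A minimal sketch, assuming flash-attn is installed, a recent transformers release with the `use_flash_attention_2` flag from #25598, and a checkpoint name used purely as an example:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The FA2 kernels only run in fp16/bf16, so a half-precision dtype is required.
model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,  # flag introduced in #25598
).to("cuda")

inputs = tokenizer("Flash Attention 2 makes generation", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```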


Let's try to add Flash Attention 2 support for more architectures! Currently supported architectures are

  • [x] Llama
  • [x] Falcon

It would be great to add the support for more architectures such as

  • [x] Bark
  • [x] Bart
  • [ ] BERT | @sorenmc
  • [ ] CLIP https://github.com/huggingface/transformers/pull/27444/
  • [x] DistilBERT
  • [x] GPT-2
  • [x] GPT-J
  • [x] GPTBigCode (Starcoder) | @susnato
  • [x] GPT-neo
  • [x] GPT-neo-x | @younesbelkada #26463
  • [x] OPT | @susnato #26414
  • [x] Llava
  • [x] VipLlava
  • [x] mBART
  • [x] Mistral
  • [x] Mixtral
  • [ ] MPT | @rajveer43
  • [ ] T5
  • [ ] Persimmon | @jeromeku
  • [x] Phi
  • [x] Whisper
  • [x] Qwen2

... and many more

Adding this feature requires following the same protocol as in https://github.com/huggingface/transformers/pull/25598. First, create a new module inside the corresponding modeling file, named xxxFlashAttention, that inherits from xxxAttention and overrides the forward method to use the public methods from flash-attn. Make sure you have access to a GPU that supports Flash Attention 2.
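As an illustration only, here is a heavily simplified sketch of that pattern. The class, attribute, and projection names below follow the Llama-style attention modules and are placeholders; real implementations also handle padding (via the varlen kernels), rotary embeddings, KV caching, and dtype checks:

```python
import torch
from flash_attn import flash_attn_func  # public flash-attn 2 kernel


class XxxFlashAttention2(XxxAttention):  # "Xxx" stands for the model being ported
    """Same weights/config as XxxAttention, but forward runs through flash-attn."""

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        bsz, q_len, _ = hidden_states.size()

        # Project and reshape to (batch, seq_len, num_heads, head_dim),
        # which is the layout flash_attn_func expects.
        query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)

        # flash-attn fuses softmax(QK^T / sqrt(d)) V into a single kernel (fp16/bf16 only).
        attn_output = flash_attn_func(query, key, value, causal=True)

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        return self.o_proj(attn_output), None, None
```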

Given that the issue is slightly challenging, we're labelling it as a good second issue!

If you are interested in taking up the challenge, comment below with the architecture name you want to integrate and open a PR!

Once you open a PR, feel free to ping @LysandreJik @ArthurZucker @amyeroberts @younesbelkada @fxmarty @SunMarc @pacman100 for a review

Motivation

Making LLMs more memory efficient and faster!

Your contribution

Reviewing PRs and possibly adding support for more models.

younesbelkada avatar Sep 22 '23 15:09 younesbelkada

Hi @younesbelkada - I want to work on adding Flash Attention 2 support for GPTBigCode (Starcoder). Can I take this task? Can you please assign this task to me?

sahilbhosale63 avatar Sep 22 '23 17:09 sahilbhosale63

Will definitely take a look next week. Great to see it merged now 💪

flozi00 avatar Sep 22 '23 20:09 flozi00

I would like to work on MPT @younesbelkada

rajveer43 avatar Sep 23 '23 10:09 rajveer43

I would like to work on OPT.

susnato avatar Sep 24 '23 17:09 susnato

Is it possible to add FlashAttention2 to GPT2 models?

ZeusFSX avatar Sep 25 '23 09:09 ZeusFSX

@sahilbhosale63 @flozi00 @rajveer43 @susnato thanks very much for your interest! Indeed, it would be great if you could help us! Before assigning you to this issue, can you confirm you have access to a GPU that supports Flash Attention 2 (https://github.com/Dao-AILab/flash-attention#installation-and-features), so that you are able to run the tests? @ZeusFSX, yes, I think that is possible; I'll update the list accordingly.
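(For reference, FA2 currently targets Ampere, Ada, and Hopper GPUs, i.e. CUDA compute capability 8.0 or higher. A quick way to check, assuming PyTorch with CUDA is installed:)

```python
import torch

# Flash Attention 2 kernels need compute capability >= 8.0 (Ampere / Ada / Hopper).
major, minor = torch.cuda.get_device_capability()
print(torch.cuda.get_device_name(0), f"- compute capability {major}.{minor}")
print("Flash Attention 2 supported:", major >= 8)
```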

younesbelkada avatar Sep 25 '23 09:09 younesbelkada

@younesbelkada Yes I have

rajveer43 avatar Sep 25 '23 10:09 rajveer43

OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. To run the flash attention tests (once the PR is ready), you can just run:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/

younesbelkada avatar Sep 25 '23 12:09 younesbelkada

@younesbelkada yes I have.

susnato avatar Sep 25 '23 12:09 susnato

Thanks @susnato, perfect then! Let me know whenever you start the PR and if you have any questions. Check out my instructions above for more details.

younesbelkada avatar Sep 25 '23 12:09 younesbelkada

@younesbelkada Unfortunately, my GPU is not supported.

sahilbhosale63 avatar Sep 25 '23 13:09 sahilbhosale63

OK perfect, I will assign you to MPT! Feel free to let me know if you need any help or have any questions. As a starting point, I would recommend having a look at #25598 and seeing if you can replicate the PR for MPT. To run the flash attention tests (once the PR is ready), you can just run:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/mpt/

Sure I will work on it!

rajveer43 avatar Sep 25 '23 16:09 rajveer43

@younesbelkada Would like to work on Persimmon. I have access to A4000, A5000, and A6000, which I believe should be compatible with FA2.

jeromeku avatar Sep 26 '23 17:09 jeromeku

Perfect, sounds great! Thanks for your help, I will assign you to Persimmon!

younesbelkada avatar Sep 26 '23 18:09 younesbelkada

Since @sahilbhosale63 is not working on GPTBigCode (Starcoder), as he said here, can I take that, @younesbelkada?

susnato avatar Sep 26 '23 18:09 susnato

Yes, no problem, thanks very much for offering your help on this! As a starting point, you can have a look at @pacman100's implementation here: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/personal_copilot/training/starcoder_flash_attn_monkey_patch.py

younesbelkada avatar Sep 26 '23 18:09 younesbelkada

@younesbelkada I would like to implement it for BERT if it hasn't already been done? A lot of the models topping MTEB still rely on this architecture! I have tested that I can run Flash Attention 2 on my NVIDIA GeForce RTX 3060 Ti.

sorenmc avatar Sep 26 '23 19:09 sorenmc

Awesome, thanks a lot for your help, ok I will assign you to BERT then!

younesbelkada avatar Sep 27 '23 08:09 younesbelkada

Hi everyone, I would like to help implement this with GPT2 if you want.

DougTrajano avatar Sep 27 '23 18:09 DougTrajano

@younesbelkada

I have a working version for Persimmon that passes the flash_attn_v2 tests except for generate_padding_right, as the original PersimmonFlashAttention does not take padding_mask as a keyword input (as opposed to the Llama and Falcon flash implementations). Is this something that needs to be changed in both Persimmon Flash v1 and v2?

Also, any plans on incorporating additional optimizations, e.g., Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally? Happy to investigate more!

Also, would like to help with Mistral-7b (just released). They use xformers memory efficient attention in their released implementation but also mention Tri Dao's FA in the blogpost.

jeromeku avatar Sep 27 '23 20:09 jeromeku

Hi @DougTrajano Awesome! Can you confirm you have access to hardware that is supported by FA-2?


@jeromeku awesome, thanks! Can you move forward with Persimmon by opening a PR so that I can have a look?

Also, any plans on incorporating additional optimizations, e.g., Flash Attention repo has fused layers for dense, rotary, and layer norm for faster training; and Triton kernels, more generally? Happy to investigate more!

If that is something that can fit nicely into the API without any breaking behaviour, that would be great!

Also, would like to help with Mistral-7b (just released). They use xformers memory efficient attention in their released implementation but also mention Tri Dao's FA in the blogpost.

I think Mistral's attention has been released in the latest version of FA-2. Would you be happy to open a PoC PR so that I can play with it and see what we can do?
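(For context, recent flash-attn releases expose sliding-window/local attention through a `window_size` argument on the public kernels, which is what Mistral's attention needs. A small hedged example, assuming flash-attn >= 2.3 and a CUDA GPU:)

```python
import torch
from flash_attn import flash_attn_func

# Toy q/k/v in the (batch, seq_len, num_heads, head_dim) layout, fp16 as FA2 requires.
q = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")

# causal=True plus a finite left window approximates Mistral-style sliding-window attention;
# window_size is (left, right), and (-1, -1) means an unrestricted window.
out = flash_attn_func(q, k, v, causal=True, window_size=(4096, 0))
print(out.shape)  # torch.Size([1, 1024, 8, 64])
```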

Again thanks a lot!

younesbelkada avatar Sep 28 '23 09:09 younesbelkada

Hi @jeromeku, I had to check internally for Mistral: given the very recent release and the urgency, we'll take this over (https://github.com/huggingface/transformers/pull/26464); if you have started a PR, I'm very happy to start from it or to add you as a co-author to the PR! We might also refactor things a bit to support the local attention introduced by Mistral, so that needs further investigation. I'll keep you posted.

younesbelkada avatar Sep 28 '23 11:09 younesbelkada

@younesbelkada what is the expected deadline to complete MPT? I have other issues to tackle, so I can plan accordingly.

rajveer43 avatar Sep 28 '23 11:09 rajveer43

Hi @younesbelkada, I am taking this up for GPT-neo.

susnato avatar Sep 28 '23 20:09 susnato

Awesome @susnato, thanks! @rajveer43 thanks for taking up MPT, I will check it out!

younesbelkada avatar Sep 28 '23 20:09 younesbelkada

Hi @DougTrajano Awesome! Can you confirm you have access to hardware that is supported by FA-2?


Yes, I'll work on AWS SageMaker.

DougTrajano avatar Sep 28 '23 22:09 DougTrajano

Would love to take on GPT2!

marcasty avatar Sep 28 '23 23:09 marcasty

Thanks for confirming @DougTrajano! @marcasty thanks a lot for your interest; @DougTrajano has taken up GPT2, would you be happy to take another model? 🙏 Can you also confirm you have access to hardware that supports FA-2?

younesbelkada avatar Sep 29 '23 06:09 younesbelkada

Hi @younesbelkada, I am taking this up for DistilBERT.

susnato avatar Sep 29 '23 08:09 susnato

@younesbelkada what about T5? I have access to compatible hardware

marcasty avatar Sep 29 '23 11:09 marcasty