
[RoBERTa-based] Add support for sdpa

Open hackyon opened this pull request 1 year ago • 15 comments

What does this PR do?

Adding support for SDPA (scaled dot product attention) for RoBERTa-based models. More context in #28005 and #28802.

Models: camembert, roberta, xlm_roberta, xlm_roberta_xl.
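
For anyone who wants to try it, here is a minimal usage sketch, assuming a transformers build that includes this PR so `from_pretrained` accepts `attn_implementation="sdpa"` for these models:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load RoBERTa with the SDPA attention implementation instead of the eager one.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained("roberta-base", attn_implementation="sdpa")

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```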

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@fxmarty @ArthurZucker @amyeroberts

hackyon avatar Apr 26 '24 20:04 hackyon

I ran slow tests for the affected models, and verified that they all pass except XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(). I suspect it's just some numerical computation error, but I'll take a quick look to see if I can find anything.

I'll also try to run some of the perf benchmarks on RoBERTa over the weekend to see how they behave.

hackyon avatar Apr 26 '24 21:04 hackyon

Preliminary perf numbers for RoBERTa (using "roberta-base" with AutoModel/AutoTokenizer).

Training

| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager, s) | Time per batch (SDPA, s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 256 | True | 0.018 | 0.015 | 24.411 | 731.752 | 736.471 | -0.641 |
| 1000 | 1 | 512 | True | 0.019 | 0.016 | 17.819 | 823.792 | 757.096 | 8.809 |
| 1000 | 2 | 256 | True | 0.020 | 0.016 | 29.890 | 760.504 | 757.096 | 0.450 |
| 1000 | 2 | 512 | True | 0.020 | 0.016 | 25.317 | 1283.793 | 907.688 | 41.435 |
| 1000 | 4 | 256 | True | 0.020 | 0.016 | 28.907 | 1094.001 | 907.289 | 20.579 |
| 1000 | 4 | 512 | True | 0.025 | 0.021 | 19.153 | 2205.299 | 1446.666 | 52.440 |

Inference

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per-token latency eager (ms) | Per-token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 64 | True | True | True | 5.357 | 5.067 | 5.716 | 333.956 | 333.956 | 0 |
| 50 | 2 | 128 | True | True | True | 5.534 | 5.181 | 6.812 | 360.089 | 360.089 | 0 |
| 50 | 2 | 256 | True | True | True | 5.823 | 5.516 | 5.577 | 412.355 | 412.355 | 0 |
| 50 | 4 | 64 | True | True | True | 5.632 | 5.344 | 5.381 | 385.611 | 385.611 | 0 |
| 50 | 4 | 128 | True | True | True | 6.101 | 5.849 | 4.304 | 437.895 | 437.877 | 0.004 |
| 50 | 4 | 256 | True | True | True | 6.91 | 6.529 | 5.824 | 542.598 | 542.598 | 0 |
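
For reference, here is a minimal sketch of the kind of eager-vs-SDPA timing comparison behind numbers like these. It is not the actual benchmark script used above; the warm-up, dtype handling, and memory measurement are simplified assumptions.

```python
import time

import torch
from transformers import AutoModel


def time_forward(attn_implementation: str, batch_size: int, seq_len: int, num_batches: int = 50) -> float:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModel.from_pretrained(
        "roberta-base", attn_implementation=attn_implementation, torch_dtype=dtype
    ).to(device).eval()

    # Random token ids are enough for a pure latency comparison.
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        model(input_ids, attention_mask=attention_mask)  # warm-up
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_batches):
            model(input_ids, attention_mask=attention_mask)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_batches


eager = time_forward("eager", batch_size=2, seq_len=128)
sdpa = time_forward("sdpa", batch_size=2, seq_len=128)
print(f"eager: {eager * 1e3:.2f} ms/batch, sdpa: {sdpa * 1e3:.2f} ms/batch, speedup: {eager / sdpa:.2f}x")
```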

hackyon avatar Apr 27 '24 00:04 hackyon

It seems like XLMRobertaXLModelTest::test_eager_matches_sdpa_generate() doesn't always fail; it's flaky and depends on the random number generator. I think it comes down to numerical stability, which can produce slightly different results between the two attention paths.

EDIT: I added a set_seed(0) to XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(), and the flake seems to have gone away.
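
For reference, a minimal sketch of why seeding removes the flake, assuming the test only needs deterministic random inputs before comparing the two attention paths (the actual test body is not reproduced here):

```python
import torch
from transformers import set_seed

# set_seed seeds Python, NumPy, and torch RNGs, so the randomly generated
# inputs are identical across runs and the eager vs. SDPA comparison no
# longer flakes on borderline numerical differences.
set_seed(0)
a = torch.randn(2, 4)
set_seed(0)
b = torch.randn(2, 4)
assert torch.equal(a, b)  # same seed -> identical tensors
```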

hackyon avatar Apr 27 '24 01:04 hackyon

@fxmarty @ArthurZucker @amyeroberts

This is ready for review! Apart from the changes to the tests and check_support_list.py, all of the changes come from "Copied from" statements. Please let me know if you have any questions!

hackyon avatar Apr 29 '24 17:04 hackyon

@hackyon, I'm curious whether implementing flash_attn is essential when writing an SDPA integration. I came across claims that FlashAttention can offer roughly a 4x efficiency boost compared to native PyTorch, yet your remarks in https://github.com/huggingface/transformers/pull/30510 suggest the actual improvement is less than 50%. Could you help shed some light on this apparent difference?

michaelsheka avatar May 08 '24 21:05 michaelsheka

@michaelshekasta I believe the 4x improvement only applies to certain models, usually larger ones where the attention computation is more expensive.
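
To illustrate (a sketch assuming a CUDA device with fp16 support): `attn_implementation="sdpa"` routes attention through torch's `scaled_dot_product_attention`, which itself dispatches to a flash, memory-efficient, or math kernel depending on dtype, device, and mask. The headline ~4x FlashAttention numbers come from large models and long sequences, so the gain for roberta-base at short sequence lengths is much smaller.

```python
import torch
import torch.nn.functional as F

# Shapes: (batch, heads, seq_len, head_dim) -- roughly roberta-base sized heads.
q = k = v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)

# torch picks the best available SDPA backend for these inputs.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 12, 128, 64])
```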

hackyon avatar May 19 '24 12:05 hackyon

@fxmarty can you have a look and ping me for the final review? 🤗

ArthurZucker avatar May 23 '24 13:05 ArthurZucker

@fxmarty , gentle bump

nbroad1881 avatar Jun 18 '24 01:06 nbroad1881

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fxmarty what's left? How can I help?

michaelsheka avatar Jun 24 '24 10:06 michaelsheka

@michaelshekasta Approval from @ArthurZucker or @amyeroberts.

fxmarty avatar Jun 24 '24 12:06 fxmarty

@fxmarty , would it be possible to merge this before the release of the next version of transformers?

nbroad1881 avatar Jul 08 '24 16:07 nbroad1881

This PR is super great. I also hope it will be merged before the next release.

kiszk avatar Jul 10 '24 01:07 kiszk

@fxmarty you are amazing! If I can help, please write to me

michaelsheka avatar Jul 10 '24 13:07 michaelsheka

@fxmarty Thank you very much. I would appreciate it if you could re-add gpt_neox for consistency. Or can I do it? I am not sure why it was dropped.

https://app.circleci.com/pipelines/github/huggingface/transformers/97500/workflows/4facc164-8c3b-4ad0-9387-be9de636e686/jobs/1291191?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-checks-link&utm_content=summary

```
Traceback (most recent call last):
  File "/root/transformers/utils/check_support_list.py", line 97, in <module>
    check_sdpa_support_list()
  File "/root/transformers/utils/check_support_list.py", line 90, in check_sdpa_support_list
    raise ValueError(
ValueError: gpt_neox should be in listed in the SDPA documentation but is not. Please update the documentation.

Exited with code exit status 1
```
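
For context, the check boils down to comparing the models that implement SDPA against the list in the docs. The sketch below is illustrative only: the file paths and the "sdpa" class-name heuristic are assumptions, not the actual utils/check_support_list.py implementation.

```python
from pathlib import Path

DOC_FILE = Path("docs/source/en/perf_infer_gpu_one.md")  # assumed location of the SDPA list
MODELS_DIR = Path("src/transformers/models")


def check_sdpa_support_list() -> None:
    doc_text = DOC_FILE.read_text(encoding="utf-8").lower()
    for model_dir in sorted(p for p in MODELS_DIR.iterdir() if p.is_dir()):
        modeling_file = model_dir / f"modeling_{model_dir.name}.py"
        if not modeling_file.exists():
            continue
        # Heuristic: treat a model as SDPA-capable if its modeling file defines
        # an Sdpa attention class.
        if "sdpaselfattention" in modeling_file.read_text(encoding="utf-8").lower():
            name = model_dir.name
            if name not in doc_text and name.replace("_", "-") not in doc_text:
                raise ValueError(
                    f"{name} should be listed in the SDPA documentation but is not. "
                    "Please update the documentation."
                )


if __name__ == "__main__":
    check_sdpa_support_list()
```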

kiszk avatar Jul 10 '24 18:07 kiszk

Thanks @kiszk, missed it when reordering the lists.

fxmarty avatar Jul 11 '24 09:07 fxmarty

gentle ping @ArthurZucker @amyeroberts

fxmarty avatar Jul 12 '24 16:07 fxmarty

@ArthurZucker @amyeroberts

fxmarty avatar Jul 16 '24 09:07 fxmarty

@fxmarty You may want to resolve conflicts.

kiszk avatar Jul 22 '24 06:07 kiszk

Sorry, I did not have time before; I'll try to get to it today or next week. It's a big PR with lots of changes, so I need to be extra careful!

ArthurZucker avatar Jul 26 '24 10:07 ArthurZucker

@ArthurZucker would you have time for this review?

kiszk avatar Aug 13 '24 02:08 kiszk

I've also seen approximately 20% faster training with XLM-RoBERTa using this PR on an RTX 4090. I've been testing it for over a week now, and it's been working without any issues. I sincerely hope this can be merged.

hotchpotch avatar Aug 15 '24 19:08 hotchpotch

@ArthurZucker Is there anything we can do to help get this PR reviewed?

kiszk avatar Aug 27 '24 15:08 kiszk

@ArthurZucker when do you think this change will appear in the transformers package? The next version?

P.S. You guys are amazing!

michaelsheka avatar Aug 28 '24 09:08 michaelsheka

It should be there in at most 2 weeks! 🤗

ArthurZucker avatar Aug 28 '24 09:08 ArthurZucker

I would like to thank everyone involved in this Pull Request from the bottom of my heart! 🎉

hotchpotch avatar Aug 28 '24 21:08 hotchpotch

It should be there in at most 2 weeks! 🤗

@ArthurZucker A gentle reminder ;-)

michaelsheka avatar Sep 10 '24 20:09 michaelsheka

We are gonna release today / tomorrow! 🤗 sorry for the delay

ArthurZucker avatar Sep 17 '24 00:09 ArthurZucker

@ArthurZucker Thanks!!! I hope it will be released today!!

michaelsheka avatar Sep 19 '24 13:09 michaelsheka

😐 Really sorry, a big, big release is coming on Wednesday; don't worry, the wait is worth it! 👀

ArthurZucker avatar Sep 21 '24 02:09 ArthurZucker