[RoBERTa-based] Add support for sdpa
What does this PR do?
Adding support for SDPA (scaled dot product attention) for RoBERTa-based models. More context in #28005 and #28802.
Models: camembert, roberta, xlm_roberta, xlm_roberta_xl.
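For reviewers who want the gist of the change: the SDPA path swaps the hand-written attention (matmul, scale, additive mask, softmax, dropout, matmul) for a single fused call to `torch.nn.functional.scaled_dot_product_attention`. Below is a minimal sketch assuming the usual `(batch, heads, seq_len, head_dim)` layout; the PR's actual `*SdpaSelfAttention` classes additionally handle cross-attention, the `output_attentions` fallback, and position-embedding variants that SDPA cannot express.

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None, dropout_p=0.0):
    # query/key/value: (batch, num_heads, seq_len, head_dim).
    # attention_mask: the additive float mask transformers prepares,
    # broadcastable to (batch, num_heads, q_len, kv_len).
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p
    )
```

Once merged, users opt in (or get it by default when supported) via `AutoModel.from_pretrained("roberta-base", attn_implementation="sdpa")`.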
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@fxmarty @ArthurZucker @amyeroberts
I ran the slow tests for the affected models and verified that they all pass except XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(). I suspect it's just numerical precision, but I'll take a quick look to see if I can find anything.
I'll also try to run some of the perf benchmarks on RoBERTa over the weekend to see how they behave.
Preliminary perf numbers for RoBERTa (using "roberta-base" with AutoModel/AutoTokenizer).
Training
| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager, s) | Time per batch (SDPA, s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 256 | True | 0.018 | 0.015 | 24.411 | 731.752 | 736.471 | -0.641 |
| 1000 | 1 | 512 | True | 0.019 | 0.016 | 17.819 | 823.792 | 757.096 | 8.809 |
| 1000 | 2 | 256 | True | 0.020 | 0.016 | 29.890 | 760.504 | 757.096 | 0.450 |
| 1000 | 2 | 512 | True | 0.020 | 0.016 | 25.317 | 1283.793 | 907.688 | 41.435 |
| 1000 | 4 | 256 | True | 0.020 | 0.016 | 28.907 | 1094.001 | 907.289 | 20.579 |
| 1000 | 4 | 512 | True | 0.025 | 0.021 | 19.153 | 2205.299 | 1446.666 | 52.440 |
Inference
| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per-token latency eager (ms) | Per-token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 2 | 64 | True | True | True | 5.357 | 5.067 | 5.716 | 333.956 | 333.956 | 0 |
| 50 | 2 | 128 | True | True | True | 5.534 | 5.181 | 6.812 | 360.089 | 360.089 | 0 |
| 50 | 2 | 256 | True | True | True | 5.823 | 5.516 | 5.577 | 412.355 | 412.355 | 0 |
| 50 | 4 | 64 | True | True | True | 5.632 | 5.344 | 5.381 | 385.611 | 385.611 | 0 |
| 50 | 4 | 128 | True | True | True | 6.101 | 5.849 | 4.304 | 437.895 | 437.877 | 0.004 |
| 50 | 4 | 256 | True | True | True | 6.910 | 6.529 | 5.824 | 542.598 | 542.598 | 0 |
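If you want to reproduce numbers in this shape, a rough timing harness along these lines works. This is an illustrative sketch only (the batch/sequence sizes and iteration counts are arbitrary), not the actual benchmark script used to produce the tables above.

```python
import time
import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("roberta-base")
batch = tok(["hello world"] * 4, padding="max_length", max_length=256,
            return_tensors="pt").to(device)

def bench(attn_implementation, iters=50, warmup=5):
    model = AutoModel.from_pretrained(
        "roberta-base", attn_implementation=attn_implementation
    ).to(device).eval()
    with torch.inference_mode():
        for _ in range(warmup):      # warm up kernels / allocator
            model(**batch)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(**batch)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print("eager:", bench("eager"), "sdpa:", bench("sdpa"))
```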
It seems like XLMRobertaXLModelTest::test_eager_matches_sdpa_generate() doesn't always fail; it's flaky and depends on the random number generator state. I think the cause is numerical stability: the eager and SDPA implementations can produce slightly different results, and some random initializations push the difference past the test's tolerance.
EDIT: I added a set_seed(0) to XLMRobertaXLModelTest::test_eager_matches_sdpa_generate(), and the flake seems to have gone away.
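For reference, the fix is just deterministic seeding at the top of the test. A simplified sketch of the override (the exact body in the PR may differ):

```python
from transformers import set_seed

def test_eager_matches_sdpa_generate(self):
    # Pin the python/numpy/torch RNGs so the randomly initialized model and
    # inputs are identical across the eager and SDPA runs; without this,
    # unlucky initializations occasionally exceed the comparison tolerance.
    set_seed(0)
    ...  # run generate() with both attn_implementations and compare outputs
```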
@fxmarty @ArthurZucker @amyeroberts
This is ready for review! With the exception of the changes to the test and check_support_list.py, all the changes come from "Copied from" comments. Please let me know if you have any questions!
@hackyon, I'm curious whether implementing Flash Attention is essential when adding SDPA support. I came across claims that Flash Attention can offer up to a ~4x efficiency boost compared to native PyTorch, but your remarks in https://github.com/huggingface/transformers/pull/30510 suggest the actual improvement is less than 50%. Could you shed some light on this apparent discrepancy?
@michaelshekasta I believe the 4x improvement only applies to certain models, usually larger models with more computationally expensive attention computations.
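To expand on that: PyTorch's `scaled_dot_product_attention` is a dispatcher that can route to the flash kernel, the memory-efficient kernel, or an unfused math fallback, and the flash kernel's headline speedups apply to the attention op itself, typically at long sequence lengths in half precision. End to end, an encoder like RoBERTa also spends much of its time in MLPs and layer norms, so a several-fold faster attention kernel translates into a much smaller total speedup. You can isolate the kernels yourself; a sketch, assuming a CUDA device and the pre-2.3 `torch.backends.cuda.sdp_kernel` spelling of this API:

```python
import torch
import torch.nn.functional as F

# fp16 on CUDA so the flash kernel is eligible; shapes are illustrative.
q = k = v = torch.randn(2, 12, 512, 64, device="cuda", dtype=torch.float16)

# Flash kernel only.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v)

# Unfused math fallback only, for comparison.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True,
                                    enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(q, k, v)
```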
@fxmarty can you have a look and ping me for the final review? 🤗
@fxmarty , gentle bump
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@fxmarty what's left? How can I help?
@michaelshekasta Approval from @ArthurZucker or @amyeroberts.
@fxmarty , would it be possible to merge this before the release of the next version of transformers?
This PR is super great. I also expect it to be merged before the next release.
@fxmarty you are amazing! If I can help, please write to me
@fxmarty Thank you very much. I would appreciate it if you could re-add gpt_neox for consistency. Or can I do it? I am not sure why it was dropped.
https://app.circleci.com/pipelines/github/huggingface/transformers/97500/workflows/4facc164-8c3b-4ad0-9387-be9de636e686/jobs/1291191?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-checks-link&utm_content=summary
```
Traceback (most recent call last):
  File "/root/transformers/utils/check_support_list.py", line 97, in <module>
    check_sdpa_support_list()
  File "/root/transformers/utils/check_support_list.py", line 90, in check_sdpa_support_list
    raise ValueError(
ValueError: gpt_neox should be in listed in the SDPA documentation but is not. Please update the documentation.

Exited with code exit status 1
```
Thanks @kiszk, missed it when reordering the lists.
gentle ping @ArthurZucker @amyeroberts
@ArthurZucker @amyeroberts
@fxmarty You may want to resolve conflicts.
Sorry, I did not have time before; I'll try to get to it today or next week. It's a big PR with lots of changes, so I need to be extra careful!
@ArthurZucker would you have time for this review?
I've also experienced approximately 20% faster training with XLM-RoBERTa using this PR on an RTX 4090. I've been testing it for over a week now, and it's been working without any issues. I sincerely hope this can be merged.
@ArthurZucker Can we help with anything reviewing this PR?
@ArthurZucker when do you think this change will appear in the transformers package? The next version?
P.S. You are so amazing guys!
It should be there in at most 2 weeks! 🤗
I would like to thank everyone involved in this Pull Request from the bottom of my heart! 🎉
@ArthurZucker A gentle reminder ;-)
> It should be there in at most 2 weeks! 🤗

@ArthurZucker A gentle reminder ;-)
We're going to release today or tomorrow! 🤗 Sorry for the delay.
@ArthurZucker Thanks!!! I hope it will be released today!!
😐 Really sorry, a big release is coming on Wednesday. Don't worry, the wait is worth it! 👀