[`BERT`] Add support for sdpa
What does this PR do?
Adds support for SDPA (scaled dot product attention) to BERT. More context in #28005.
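For readers who want the gist without digging into the diff, here's a minimal sketch of the idea, not the exact attention class added in this PR: the eager softmax(QK^T / sqrt(d))V path gets replaced by a call to `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to fused kernels when available. The class and helper names below are illustrative only.

```python
# Minimal sketch only, not the exact implementation in this PR: the eager
# softmax(QK^T / sqrt(d)) @ V path is replaced by F.scaled_dot_product_attention,
# which can dispatch to fused kernels (FlashAttention, memory-efficient attention).
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SdpaSelfAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = dropout

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        q = self._split_heads(self.query(hidden_states))
        k = self._split_heads(self.key(hidden_states))
        v = self._split_heads(self.value(hidden_states))
        # attention_mask is an additive 4D mask (or None); SDPA applies the 1/sqrt(d) scaling itself.
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        batch, _, seq, _ = attn.shape
        return attn.transpose(1, 2).reshape(batch, seq, self.num_heads * self.head_dim)
```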
Benchmark results on an A100-80GB (12 CPU cores, 96.6 GB RAM, Ubuntu 22.04), using BertLMHeadModel.
Training benchmark based on fxmarty's script:
num_training_steps | batch_size | seq_len | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
---|---|---|---|---|---|---|---|---|
1000 | 1 | 256 | 0.022 | 0.018 | 23.905 | 1128.190 | 1065.286 | 5.905 |
1000 | 1 | 512 | 0.034 | 0.028 | 20.473 | 1345.791 | 1093.933 | 23.023 |
1000 | 2 | 256 | 0.031 | 0.026 | 18.701 | 1175.685 | 1093.933 | 7.473 |
1000 | 2 | 512 | 0.057 | 0.047 | 21.315 | 2123.874 | 1370.097 | 55.016 |
1000 | 4 | 256 | 0.052 | 0.044 | 16.446 | 1784.135 | 1369.489 | 30.277 |
1000 | 4 | 512 | 0.106 | 0.087 | 21.524 | 3706.609 | 2196.791 | 68.728 |
Inference benchmark based on fxmarty's script:
num_batches | batch_size | seq_len | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
---|---|---|---|---|---|---|---|---|
50 | 1 | 64 | 5.906 | 5.420 | 8.962 | 271.610 | 271.407 | 0.075 |
50 | 1 | 128 | 5.825 | 5.402 | 7.834 | 279.157 | 279.718 | -0.200 |
50 | 2 | 64 | 6.190 | 5.349 | 15.709 | 291.489 | 291.751 | -0.090 |
50 | 2 | 128 | 6.168 | 5.360 | 15.066 | 307.514 | 307.776 | -0.085 |
50 | 4 | 64 | 6.262 | 5.392 | 16.137 | 332.177 | 332.440 | -0.079 |
50 | 4 | 128 | 6.201 | 5.382 | 15.215 | 364.271 | 364.742 | -0.129 |
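The numbers above come from fxmarty's scripts. For anyone who wants to reproduce the general shape of the measurement, here's a minimal sketch of the kind of loop involved (not the actual benchmark script). It assumes a CUDA device and the `attn_implementation` kwarg of `from_pretrained`; the checkpoint, dtype, and shapes are illustrative.

```python
# Rough sketch of an eager-vs-SDPA latency/memory comparison; not the actual benchmark script.
import time

import torch
from transformers import AutoTokenizer, BertLMHeadModel


def benchmark(attn_implementation: str, batch_size: int = 4, seq_len: int = 512, steps: int = 50):
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertLMHeadModel.from_pretrained(
        "bert-base-uncased",
        attn_implementation=attn_implementation,
        is_decoder=True,
        torch_dtype=torch.float16,
    ).to("cuda")
    model.eval()
    inputs = tokenizer(
        ["hello world"] * batch_size,
        padding="max_length",
        max_length=seq_len,
        truncation=True,
        return_tensors="pt",
    ).to("cuda")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(steps):
            model(**inputs)
    torch.cuda.synchronize()

    per_batch = (time.perf_counter() - start) / steps
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    return per_batch, peak_mb


for impl in ("eager", "sdpa"):
    latency, mem = benchmark(impl)
    print(f"{impl}: {latency * 1000:.2f} ms/batch, peak mem {mem:.1f} MB")
```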
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada
(cc @fxmarty)
Hey @ArthurZucker @younesbelkada
I was thinking SDPA (#28005) could be a good addition to BERT, so I drafted this change. It doesn't look too hairy so far.
As @ArthurZucker mentioned, BERT doesn't have a lot of params, so there might not be much of a speedup, but this didn't look too difficult to implement, so I figured even a small improvement could still be helpful (as an aside, there's been some benchmarking of Flash Attention when training other BERT implementations, and it still shows decent improvements).
Can you let me know if this is worth pursuing? If so, I'll add the tests and also fix the fix-copies dependencies.
Thanks!
I think a good way to see if it is worth the effort is to benchmark your code and check whether you get speedups in different contexts!
Sounds good, lemme look into that
@ArthurZucker I did some training and inference benchmarking for my change and posted the results in the PR description.
It looks like there are decent improvements across the board (percentage-wise, but I think the improvements would add up if we're doing a lot of training/inference). I think it could be a good addition. Thoughts?
Sounds like a good addition then! I'll let @fxmarty review and will be doing the final pass!
Just curious, is it similar to https://github.com/huggingface/transformers/pull/27478 ? Seems also https://github.com/huggingface/transformers/pull/28713 is highly related.
re: @pommedeterresautee
Yes, it's similar. SDPA is built into PyTorch and can dispatch to Flash Attention (v1) depending on the environment. AFAIK Flash Attention 2 isn't supported in SDPA yet, but it could be supported down the road (and since that would be handled inside PyTorch itself, it shouldn't need many changes on our end).
Thanks, I think it is supported now: https://pytorch.org/blog/pytorch2-2/ says "scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups compared to previous versions."
Oh nice, so I guess we could get FA2 for free eventually (when we upgrade pytorch).
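(For anyone curious which backend SDPA actually picks on their setup, here's a minimal sketch, assuming PyTorch >= 2.0: restricting SDPA to the FlashAttention backend makes ineligible setups raise an error instead of silently falling back to the math path.)

```python
# Sketch: force SDPA to use only the FlashAttention backend. If the inputs or
# hardware are not eligible (dtype, mask, head dim, GPU arch), this raises
# instead of silently dispatching to the math/memory-efficient kernels.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 12, 128, 64])
```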
Thanks for the links to similar work. I think they could cause some merge conflicts, so I'll message the authors and try to resolve the conflicts before this goes in.
I've rebased onto HEAD and marked the PR as ready for review. I had to dig through a couple of issues to get the tests to pass, so let me know if you want to chat about any of them.
Thanks!
@fxmarty @hackyon There are still several tests failing related to this PR. Once these are resolved you can ping me again for a final review.
The tests are passing now. I also verified that test_modeling_bert passes with RUN_SLOW=1 (which includes the tests that ensure model output is the same for eager and SDPA attention).
Please take another look when you get a chance. Thanks!
Thanks for reviewing!
Some general comments:
- Let's wait for the merging of "Add tie_weights() to LM heads and set bias in set_output_embeddings()" (#28948) before merging this in
Yup. I'll merge this PR to HEAD to get rid of the diffs once that other PR goes in.
- It would be good to add the performance numbers in the PR description to BERT's model page, similar to what's done for Flash Attention, e.g. [here](https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/gpt_neox#using-flash-attention-2).
I'll look into it.
- `test_eager_matches_sdpa_inference` should be run for all existing models with SDPA implemented to confirm compatibility with the change in `processed_inputs`
This one is tricky. Locally, this method is already failing for some of the models on main/HEAD without my change (such as for FalconModelTest and Qwen2ModelTest). Any chance you can try to run this test on main/HEAD and see if you are seeing those failures on your machine as well?
- We shouldn't be setting `self._use_sdpa` for models that don't have an SDPA attention class. We can just about get away with it for the models which have an attention dict, but not for the other models.
I've removed them.
I added some documentation on SDPA to the BERT model page.
For the inference tests, I am seeing the same failures in FalconModelTest and Qwen2ModelTest with and without my change (i.e. on main/HEAD). They should be unrelated to my changes.
I think the Falcon failure is likely just an edge case problem (for some reason the difference is a little higher in this one case), whereas the Qwen2 failure is likely due to an incorrect SDPA implementation.
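For reference, this is roughly the kind of equivalence check `test_eager_matches_sdpa_inference` performs, just a sketch rather than the actual test; the tolerances are illustrative, and the real test sweeps dtypes, padding, and batch sizes.

```python
# Sketch of an eager-vs-SDPA output comparison for the same checkpoint.
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")

eager = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager").eval()
sdpa = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = eager(**inputs).last_hidden_state
    out_sdpa = sdpa(**inputs).last_hidden_state

# Illustrative tolerances; small numerical differences between kernels are expected.
torch.testing.assert_close(out_eager, out_sdpa, atol=1e-4, rtol=1e-4)
```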
@hackyon can you update https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention as well?
Updated! Thanks for reviewing again.
We're still waiting on #29024 to go in for the tests here to pass, but otherwise, the code here should be more or less complete.
LGTM, great work!
Thanks!
One thing you could do to be sure:
CUDA_VISIBLE_DEVICES=0 pytest tests/models/bert -s -vvvvv
Ran this, and there was an unrelated test failure, but otherwise everything else passes.
FAILED tests/models/bert/test_tokenization_bert.py::BertTokenizationTest::test_saving_tokenizer_trainer - TypeError: Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'
====================================================== 1 failed, 294 passed, 96 skipped, 409 warnings in 252.55s (0:04:12) =======================================================
I'm not a super fan of the complexity of _prepare_4d_causal_attention_mask_for_sdpa, and we should not add it to our new code IMO.
It's not as simple as that: we want to drop the attention_mask in some cases, otherwise Flash Attention is never used with SDPA (as with Llama right now). But _prepare_4d_causal_attention_mask_for_sdpa should indeed be rewritten; IMO that's unrelated to this PR.
Thanks for the review @ArthurZucker! 🙏 Sorry it took me longer to respond this time, have been busy this past week.
As @fxmarty mentioned, `_prepare_4d_causal_attention_mask_for_sdpa` is used pretty widely in all the SDPA implementations due to the nuances around the attention mask (for example, it's in whisper, bart, mistral, and pretty much all the other models AFAICT). Are you cool with leaving this in for now, and coming back to refactor them later once the nuances are figured out? (A rough sketch of the nuance is below.)
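Here's a rough sketch of what the helper is working around, not its actual implementation: when there's no padding, passing `attn_mask=None` (and letting `is_causal` handle causality) keeps SDPA eligible for the FlashAttention backend, while always materializing a 4D additive mask forces a slower path. The function name and mask conventions below are purely illustrative.

```python
# Illustrative sketch of the mask-dropping logic around SDPA, not the actual helper.
import torch
import torch.nn.functional as F


def sdpa_with_optional_mask(q, k, v, padding_mask, is_causal):
    # padding_mask: (batch, kv_len) with 1 for real tokens, 0 for padding.
    if bool(padding_mask.all()):
        # No padding anywhere: skip the mask entirely so the flash backend stays eligible.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=is_causal)

    # Padding present: build an additive mask broadcastable to (batch, heads, q_len, kv_len).
    min_val = torch.finfo(q.dtype).min
    additive = (1.0 - padding_mask[:, None, None, :].to(q.dtype)) * min_val
    if is_causal:
        q_len, kv_len = q.shape[-2], k.shape[-2]
        causal = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device))
        causal_bias = torch.zeros(q_len, kv_len, dtype=q.dtype, device=q.device)
        causal_bias = causal_bias.masked_fill(~causal, min_val)
        additive = additive + causal_bias  # broadcasts to (batch, 1, q_len, kv_len)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
```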
Of course!
I've updated based on comments. Please take another look, thanks!
@hackyon could you merge/rebase on main?
Sure, I just merged with main/HEAD. @amyeroberts @ArthurZucker do you mind taking a look?
I'm having trouble starting my cloud server right now due to high demand, but I'll run it through the slow tests later on when it works again.
Hey! Sure I was off for a bit but will have a look
Oh wow
Thanks!
I merged with main/HEAD, and re-ran the RUN_SLOW tests for both bert and also for test_eager_matches_sdpa_inference and they work as expected. There were existing failures for test_eager_matches_sdpa_inference with RUN_SLOW on main/HEAD, but nothing new introduced by this change.
I'm not sure about this test_pipelines_tf failure. I haven't touched any code with tf, and I was able to get the failing test test_stop_sequence_stopping_criteria to pass locally, so I'm thinking it's a flake or unrelated to this change.
Hi @hackyon - great to see this ready to merge!
The generation tests aren't related to this diff and are failing on other PRs. We're working to push a fix to main - will let you know when resolved, you can rebase and hopefully we have full 🟢 for merging 🤗
Thanks @amyeroberts @ArthurZucker
Just remerged with main/HEAD, and the unrelated failing TF pipeline test now passes. I checked the bert tests again with RUN_SLOW for good measure, and they continue to pass.
Let me know if there's anything else I could do here. Thanks!
@ArthurZucker Please let me know if there's anything else you'd like me to do for this PR. Thanks!
Remerged with the latest main, and fixed a test.
@ArthurZucker @amyeroberts @fxmarty Please let me know if there's anything I can do here.