
[`BERT`] Add support for SDPA

hackyon opened this pull request 1 year ago · 16 comments

What does this PR do?

Adding support for SDPA (scaled dot product attention) for BERT. More context in #28005.
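For reference, opting into the new attention implementation would look roughly like this once this change is in (a minimal sketch; the checkpoint name and dtype are just examples):

```python
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
model = BertModel.from_pretrained(
    "bert-base-uncased",
    attn_implementation="sdpa",   # request the SDPA attention path added in this PR
    torch_dtype=torch.float16,    # example dtype; SDPA also works in fp32
).to("cuda")

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
```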

Benchmarking Results on A100-80GB, CPUx12, RAM 96.6GB, OS Ubuntu 22.04, using BertLMHeadModel

Training benchmark based on fxmarty's script:

| num_training_steps | batch_size | seq_len | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 256 | 0.022 | 0.018 | 23.905 | 1128.190 | 1065.286 | 5.905 |
| 1000 | 1 | 512 | 0.034 | 0.028 | 20.473 | 1345.791 | 1093.933 | 23.023 |
| 1000 | 2 | 256 | 0.031 | 0.026 | 18.701 | 1175.685 | 1093.933 | 7.473 |
| 1000 | 2 | 512 | 0.057 | 0.047 | 21.315 | 2123.874 | 1370.097 | 55.016 |
| 1000 | 4 | 256 | 0.052 | 0.044 | 16.446 | 1784.135 | 1369.489 | 30.277 |
| 1000 | 4 | 512 | 0.106 | 0.087 | 21.524 | 3706.609 | 2196.791 | 68.728 |

Inference benchmark based on fxmarty's script:

| num_batches | batch_size | seq_len | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|
| 50 | 1 | 64 | 5.906 | 5.420 | 8.962 | 271.610 | 271.407 | 0.075 |
| 50 | 1 | 128 | 5.825 | 5.402 | 7.834 | 279.157 | 279.718 | -0.200 |
| 50 | 2 | 64 | 6.190 | 5.349 | 15.709 | 291.489 | 291.751 | -0.090 |
| 50 | 2 | 128 | 6.168 | 5.360 | 15.066 | 307.514 | 307.776 | -0.085 |
| 50 | 4 | 64 | 6.262 | 5.392 | 16.137 | 332.177 | 332.440 | -0.079 |
| 50 | 4 | 128 | 6.201 | 5.382 | 15.215 | 364.271 | 364.742 | -0.129 |
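For anyone wanting to reproduce numbers in the same spirit, here is a rough sketch of an eager-vs-SDPA timing loop (this is not fxmarty's actual script; the checkpoint, shapes, and iteration counts are illustrative):

```python
import time
import torch
from transformers import BertLMHeadModel

def time_forward(attn_implementation, num_batches=50, batch_size=2, seq_len=128):
    model = BertLMHeadModel.from_pretrained(
        "bert-base-uncased",                      # example checkpoint
        attn_implementation=attn_implementation,  # "eager" or "sdpa"
        torch_dtype=torch.float16,
        is_decoder=True,
    ).to("cuda").eval()
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device="cuda")
    with torch.no_grad():
        for _ in range(5):                        # warmup
            model(input_ids)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_batches):
            model(input_ids)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_batches

print(f"eager: {time_forward('eager'):.4f} s/forward")
print(f"sdpa:  {time_forward('sdpa'):.4f} s/forward")
```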

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@ArthurZucker @younesbelkada

(cc @fxmarty)

hackyon avatar Jan 31 '24 22:01 hackyon

Hey @ArthurZucker @younesbelkada

I was thinking SDPA (#28005) could be a good addition to BERT, so I drafted this change. It doesn't look too hairy so far.

As @ArthurZucker mentioned, BERT doesn't have a lot of params, so there might not be much of a speedup, but this didn't look too difficult to implement, so I figured even a small improvement could still be helpful (as an aside, there's been some benchmarking of Flash Attention on training other BERT implementations, and it still shows decent improvements).

Can you let me know if this is worth pursuing? If so, I'll add the tests and also fix the fix-copies dependencies.

Thanks!

hackyon avatar Jan 31 '24 23:01 hackyon

I think a good way to see if it is worth the shot is to benchmark your code and check whether you get speedups in different contexts!

ArthurZucker avatar Feb 01 '24 13:02 ArthurZucker

Sounds good, lemme look into that

hackyon avatar Feb 01 '24 16:02 hackyon

@ArthurZucker I did some training and inference benchmarking for my change and posted the results in the PR description.

It looks like there are decent improvements across the board (percentage-wise, but I think the improvements would add up if we're doing a lot of training/inference). I think it could be a good addition. Thoughts?

hackyon avatar Feb 06 '24 07:02 hackyon

Sounds like a good addition then! I'll let @fxmarty review and will be doing the final pass!

ArthurZucker avatar Feb 07 '24 07:02 ArthurZucker

Just curious, is it similar to https://github.com/huggingface/transformers/pull/27478? It also seems https://github.com/huggingface/transformers/pull/28713 is highly related.

pommedeterresautee avatar Feb 07 '24 16:02 pommedeterresautee

re: @pommedeterresautee

Yes, it's similar. SDPA is built into PyTorch and can dispatch to Flash Attention (1) depending on the environment. AFAIK Flash Attention 2 isn't supported in SDPA yet, but it could be supported down the road (that would be built into PyTorch itself and shouldn't need many changes on our end).
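For context, the primitive everything dispatches to is torch.nn.functional.scaled_dot_product_attention; a minimal standalone call looks roughly like this (shapes and dtype are just illustrative):

```python
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim) — illustrative shapes
q = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)

# PyTorch picks the fastest available backend (FlashAttention, memory-efficient
# attention, or the math fallback) based on the inputs and the environment.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
```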

hackyon avatar Feb 07 '24 16:02 hackyon

Thanks, I think it is supported now. From https://pytorch.org/blog/pytorch2-2/: "scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups compared to previous versions."
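One way to check whether the FlashAttention backend is actually being picked is to restrict SDPA to that kernel around the call (a sketch; on more recent PyTorch versions the same thing is exposed as torch.nn.attention.sdpa_kernel):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the FlashAttention backend; if the inputs or environment don't
# support it, the call raises instead of silently falling back to another kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
```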

pommedeterresautee avatar Feb 07 '24 16:02 pommedeterresautee

Oh nice, so I guess we could get FA2 for free eventually (when we upgrade pytorch).

Thanks for the links to similar work. I think those PRs could cause some merge conflicts, so I'll reach out to their authors and try to resolve things before this goes in.

hackyon avatar Feb 07 '24 16:02 hackyon

I've rebased off of HEAD and marked this as ready for review. I had to dig through a couple of issues to get the tests to pass; let me know if you want to chat about any of them.

Thanks!

hackyon avatar Feb 08 '24 15:02 hackyon

@fxmarty @hackyon There are still several tests failing related to this PR. Once these are resolved, you can ping me again for a final review.

amyeroberts avatar Feb 08 '24 16:02 amyeroberts

The tests are passing now. I also verified that test_modeling_bert passes with RUN_SLOW=1 (which includes the tests that check the model output is the same for the eager and SDPA attention implementations).

Please take another look when you get a chance. Thanks!

hackyon avatar Feb 08 '24 21:02 hackyon

Thanks for reviewing!

Some general comments:

Yup. I'll merge this PR to HEAD to get rid of the diffs once that other PR goes in.

  • It would be good to add the performance numbers in the PR description to BERT's model page, similar to what's done for Flash Attention e.g. [here](https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/gpt_neox#using-flash-attention-2).

I'll look into it.

  • test_eager_matches_sdpa_inference should be run for all existing models with SDPA implemented to confirm compatibility with the change in processed_inputs

This one is tricky. Locally, this method is already failing for some models on main/HEAD without my change (such as FalconModelTest and Qwen2ModelTest). Any chance you can run this test on main/HEAD and see if you get the same failures on your machine?
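For anyone trying to reproduce this locally, the gist of that test is an eager-vs-SDPA output comparison, roughly like the sketch below (simplified, not the actual test code; the checkpoint and tolerances are illustrative — the real test sweeps dtypes, padding patterns, and batch sizes):

```python
import torch
from transformers import BertModel

input_ids = torch.randint(0, 30522, (2, 64))  # random token ids, illustrative

eager = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager").eval()
sdpa = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = eager(input_ids).last_hidden_state
    out_sdpa = sdpa(input_ids).last_hidden_state

# Illustrative tolerances only; numerical differences between kernels are expected.
assert torch.allclose(out_eager, out_sdpa, atol=1e-4, rtol=1e-4)
```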

  • We shouldn't be setting self._use_sdpa on models that don't have an SDPA attention class. We can just about get away with it for the models which have an attention dict, but not for the other models.

I've removed them.

hackyon avatar Feb 14 '24 18:02 hackyon

I added some documentation on SDPA to the BERT model page.

For the inference tests, I am seeing the same failures in FalconModelTest and Qwen2ModelTest with and without my change (i.e. on main/HEAD), so they should be unrelated to this PR.

I think the Falcon failure is likely just an edge case problem (for some reason the difference is a little higher in this one case), whereas the Qwen2 failure is likely due to an incorrect SDPA implementation.

hackyon avatar Feb 15 '24 06:02 hackyon

@hackyon can you update https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention as well?

fxmarty avatar Feb 16 '24 09:02 fxmarty

  • @hackyon can you update https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention as well?

Updated! Thanks for reviewing again.

We're still waiting on #29024 to go in for the tests here to pass, but otherwise, the code here should be more or less complete.

hackyon avatar Feb 16 '24 15:02 hackyon

  • LGTM, great work!

Thanks!

  • One thing you could do to be sure: CUDA_VISIBLE_DEVICES=0 pytest tests/models/bert -s -vvvvv

Ran this; there was one unrelated test failure, but everything else passes.

FAILED tests/models/bert/test_tokenization_bert.py::BertTokenizationTest::test_saving_tokenizer_trainer - TypeError: Accelerator.__init__() got an unexpected keyword argument 'use_seedable_sampler'
====================================================== 1 failed, 294 passed, 96 skipped, 409 warnings in 252.55s (0:04:12) =======================================================

hackyon avatar Feb 19 '24 13:02 hackyon

  • Not super fan of the complexity of _prepare_4d_causal_attention_mask_for_sdpa, and we should not add it in our new code IMO.

It's not quite as simple as that: we want to drop the attention_mask in some cases, otherwise FA2 is never used with SDPA (as with Llama right now). But _prepare_4d_causal_attention_mask_for_sdpa should indeed be rewritten; IMO that's unrelated to this PR.
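To illustrate the trade-off: the helper's job is to drop the mask entirely when it carries no information (so SDPA can dispatch to the fused kernels) and only otherwise expand it to the 4D additive mask SDPA expects. A simplified sketch of that idea (a hypothetical helper, not the actual implementation):

```python
import torch

def prepare_sdpa_attention_mask(attention_mask, dtype):
    # Sketch: return None when the 2D padding mask is absent or all ones, so
    # SDPA can pick the FlashAttention / memory-efficient kernels; otherwise
    # expand it to the 4D additive mask SDPA expects.
    if attention_mask is None or bool(attention_mask.all()):
        return None
    expanded = attention_mask[:, None, None, :].to(dtype)  # (bsz, 1, 1, seq_len)
    return (1.0 - expanded) * torch.finfo(dtype).min       # 0 = attend, large negative = masked
```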

fxmarty avatar Feb 27 '24 11:02 fxmarty

Thanks for the review @ArthurZucker! 🙏 Sorry it took me longer to respond this time; I've been busy this past week.

As @fxmarty mentioned, _prepare_4d_causal_attention_mask_for_sdpa is used pretty widely in all the sdpa implementations due to the nuances around the attention mask (for example, it's in whisper, bart, mistral, and pretty much all the other models AFAICT). Are you cool with leaving this in for now, and coming back to refactor them later once the nuances are figured out?

hackyon avatar Mar 06 '24 04:03 hackyon

Of course!

ArthurZucker avatar Mar 07 '24 05:03 ArthurZucker

I've updated based on comments. Please take another look, thanks!

hackyon avatar Mar 07 '24 14:03 hackyon

@hackyon could you merge/rebase on main?

fxmarty avatar Mar 19 '24 06:03 fxmarty

Sure, I just merged with main/HEAD. @amyeroberts @ArthurZucker do you mind taking a look?

I'm having trouble starting my cloud server right now due to high demand, but I'll run it through the slow tests later on when it works again.

hackyon avatar Mar 19 '24 15:03 hackyon

Hey! Sure, I was off for a bit but will have a look.

ArthurZucker avatar Mar 25 '24 08:03 ArthurZucker

Oh wow

ArthurZucker avatar Apr 05 '24 07:04 ArthurZucker

Thanks!

I merged with main/HEAD and re-ran the RUN_SLOW tests for both bert and test_eager_matches_sdpa_inference, and they work as expected. There were existing failures for test_eager_matches_sdpa_inference with RUN_SLOW on main/HEAD, but nothing new introduced by this change.

I'm not sure about this test_pipelines_tf failure. I haven't touched any TF code, and I was able to get the failing test test_stop_sequence_stopping_criteria to pass locally, so I'm thinking it's a flake or unrelated to this change.

hackyon avatar Apr 06 '24 22:04 hackyon

Hi @hackyon - great to see this ready to merge!

The generation tests aren't related to this diff and are failing on other PRs. We're working to push a fix to main - will let you know when resolved, you can rebase and hopefully we have full 🟢 for merging 🤗

amyeroberts avatar Apr 08 '24 08:04 amyeroberts

Thanks @amyeroberts @ArthurZucker

Just remerged with main/HEAD, and the unrelated failing TF pipeline test now passes. I checked the bert tests again with RUN_SLOW for good measure, and they continue to pass.

Let me know if there's anything else I could do here. Thanks!

hackyon avatar Apr 11 '24 19:04 hackyon

@ArthurZucker Please let me know if there's anything else you'd like me to do for this PR. Thanks!

hackyon avatar Apr 15 '24 15:04 hackyon

Remerged with the latest main, and fixed a test.

@ArthurZucker @amyeroberts @fxmarty Please let me know if there's anything I can do here.

hackyon avatar Apr 22 '24 13:04 hackyon