Open to contribution: adding `torch.nn.functional.scaled_dot_product_attention` support for more architectures

Open fxmarty opened this issue 1 year ago • 10 comments

Feature request

In Transformers 4.36, we started adding native support for torch.nn.functional.scaled_dot_product_attention (SDPA), which is enabled by default in Transformers: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention

SDPA allows dispatching to memory-efficient attention and flash attention on supported GPUs (currently NVIDIA-only), and even on Intel CPUs.

For the record, here's a benchmark on some currently supported models:

Training benchmark, run on A100-SXM4-80GB.

| Model | Batch size | Sequence length | Time per batch ("eager", s) | Time per batch ("sdpa", s) | Speedup | Peak memory ("eager", MB) | Peak memory ("sdpa", MB) | Memory savings |
|---|---|---|---|---|---|---|---|---|
| llama2 7b | 4 | 1024 | 1.065 | 0.90 | 19.4% | 73878.28 | 45977.81 | 60.7% |
| llama2 7b | 4 | 2048 | OOM | 1.87 | / | OOM | 78394.58 | SDPA does not OOM |
| llama2 7b | 1 | 2048 | 0.64 | 0.48 | 32.0% | 55557.01 | 29795.63 | 86.4% |
| llama2 7b | 1 | 3072 | OOM | 0.75 | / | OOM | 37916.08 | SDPA does not OOM |
| llama2 7b | 1 | 4096 | OOM | 1.03 | / | OOM | 46028.14 | SDPA does not OOM |
| llama2 7b | 2 | 4096 | OOM | 2.05 | / | OOM | 78428.14 | SDPA does not OOM |

Inference benchmark, run on A100-SXM4-80GB.

| Model | Batch size | Prompt length | Num new tokens | Per token latency "eager" (ms) | Per token latency "sdpa" (ms) | Speedup |
|---|---|---|---|---|---|---|
| llama2 13b | 1 | 1024 | 1 (prefill) | 178.66 | 159.36 | 12.11% |
| llama2 13b | 1 | 100 | 100 | 40.35 | 37.62 | 7.28% |
| llama2 13b | 8 | 100 | 100 | 40.55 | 38.06 | 6.53% |
| Whisper v3 large | 1 | / | 62 | 20.05 | 18.90 | 6.10% |
| Whisper v3 large | 8 | / | 77 | 25.42 | 24.77 | 2.59% |
| Whisper v3 large | 16 | / | 77 | 28.51 | 26.32 | 8.34% |
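
For readers checking the percentages: the "Speedup" and "Memory savings" columns appear consistent with the ratio eager / sdpa − 1, not with 1 − sdpa / eager. That formula is an inference from the numbers and is not stated in the issue; the helper name `relative_gain` is hypothetical. A quick check on two rows:

```python
def relative_gain(eager: float, sdpa: float) -> float:
    """Percentage by which the eager figure exceeds the SDPA figure."""
    return (eager / sdpa - 1.0) * 100.0


# llama2 13b prefill latency (ms), from the inference benchmark
prefill_speedup = relative_gain(178.66, 159.36)

# llama2 7b, batch size 4, sequence length 1024, peak memory (MB)
memory_savings = relative_gain(73878.28, 45977.81)

print(f"{prefill_speedup:.2f}% {memory_savings:.1f}%")
```

Some other rows differ from this formula by a few tenths of a percent, presumably because the reported percentages were computed from unrounded timings.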

Previously, we had partial support for SDPA in Optimum BetterTransformer, but we are now looking to slowly deprecate it in favor of upstream support for SDPA directly in Transformers.

Here are the architectures for which support has been requested:

  • [ ] Codegen (https://github.com/huggingface/optimum/issues/1050)
  • [ ] LLAVA (https://github.com/huggingface/optimum/issues/1592)
  • [ ] Marian (https://github.com/huggingface/optimum/issues/1142)
  • [x] Mistral (https://github.com/huggingface/optimum/issues/1553)
  • [ ] LongT5 (https://github.com/huggingface/optimum/issues/1506)
  • [ ] ViT (https://github.com/huggingface/optimum/issues/1553)

The integration could take inspiration from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/decoder_models.py & https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py

Motivation

Faster training & inference, lower memory requirement

Your contribution

I may work on some at some point, but contributions are most welcome.

To add SDPA support for a model, refer to https://github.com/huggingface/transformers/pull/26572 and roughly follow these steps:

  • Create an XxxSdpaAttention class inheriting from XxxAttention and implement the attention logic using SDPA
  • Use _prepare_4d_causal_attention_mask_for_sdpa instead of _prepare_4d_causal_attention_mask for SDPA
  • Use _prepare_4d_attention_mask_for_sdpa instead of _prepare_4d_attention_mask for SDPA
  • Add _supports_sdpa = True to XxxPreTrainedModel
  • Add an "sdpa" key to XXX_ATTENTION_CLASSES in the model's modeling file
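
A minimal sketch of how these steps could fit together, using a simplified, hypothetical `XxxAttention` and not any real Transformers class. The projection layout, the `_split_heads` / `_merge_heads` helpers, and the stand-in `XxxPreTrainedModel` are assumptions made for illustration. Real models also handle KV caches, dropout, and the mask preparation helpers named above, all of which are left out here.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class XxxAttention(nn.Module):
    """Hypothetical eager attention: materializes the full score matrix."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        bsz, seq_len, _ = x.shape
        return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x):
        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        bsz, _, seq_len, _ = x.shape
        return x.transpose(1, 2).reshape(bsz, seq_len, self.num_heads * self.head_dim)

    def forward(self, hidden_states, attention_mask=None):
        q = self._split_heads(self.q_proj(hidden_states))
        k = self._split_heads(self.k_proj(hidden_states))
        v = self._split_heads(self.v_proj(hidden_states))
        scores = q @ k.transpose(-1, -2) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask  # additive 4D mask
        probs = scores.softmax(dim=-1)
        return self.o_proj(self._merge_heads(probs @ v))


class XxxSdpaAttention(XxxAttention):
    """Same weights and signature; only the attention computation changes."""

    def forward(self, hidden_states, attention_mask=None):
        q = self._split_heads(self.q_proj(hidden_states))
        k = self._split_heads(self.k_proj(hidden_states))
        v = self._split_heads(self.v_proj(hidden_states))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        return self.o_proj(self._merge_heads(out))


# Registry keyed by attention implementation name
XXX_ATTENTION_CLASSES = {"eager": XxxAttention, "sdpa": XxxSdpaAttention}


class XxxPreTrainedModel:
    """Stand-in for the real base class; only the flag is shown."""

    _supports_sdpa = True


torch.manual_seed(0)
eager = XXX_ATTENTION_CLASSES["eager"](hidden_size=64, num_heads=4).eval()
sdpa = XXX_ATTENTION_CLASSES["sdpa"](hidden_size=64, num_heads=4).eval()
sdpa.load_state_dict(eager.state_dict())  # share weights so outputs are comparable

x = torch.randn(2, 8, 64)
# Additive causal mask: 0 where attention is allowed, -inf elsewhere
causal_mask = torch.full((8, 8), float("-inf")).triu(diagonal=1)[None, None]
with torch.no_grad():
    max_diff = (eager(x, causal_mask) - sdpa(x, causal_mask)).abs().max().item()
print(f"max abs difference between eager and sdpa: {max_diff:.2e}")
```

Inheriting from the eager class keeps the parameter names identical, so existing checkpoints load into either implementation without conversion.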

fxmarty avatar Dec 13 '23 12:12 fxmarty

Hi @fxmarty, I can take a look at this issue, and I can ask questions if necessary. Or has anyone taken it already?

ENate avatar Dec 14 '23 13:12 ENate

Does anyone know whether LongT5 and all T5 models are blocked by bias support in flash attention?

https://github.com/Dao-AILab/flash-attention/pull/617

davidan5 avatar Dec 18 '23 19:12 davidan5

Hi @davidan5 are you working on the implementation?

ENate avatar Dec 19 '23 08:12 ENate

@ENate I was trying to understand the status and get an estimate of the code changes to see if I can contribute.

davidan5 avatar Dec 19 '23 12:12 davidan5

I see.

ENate avatar Dec 19 '23 13:12 ENate

I'm interested in taking a look at this for the Mistral model if that's still needed. Otherwise, please let me know if there are any other models that still need some work. Thanks

hackyon avatar Jan 29 '24 22:01 hackyon

Is LongT5 still open?

ENate avatar Jan 29 '24 23:01 ENate

Mistral is already covered! As for LongT5, if it is like T5 and has an attention bias, that might not be supported.

ArthurZucker avatar Jan 30 '24 09:01 ArthurZucker

Oh yea, looks like you added support for Mistral/Mixtral last month.

It doesn't seem to be supported for BERT yet (I think someone else is working on FA2 but not SDPA), so I'll take a crack at it. It looks like there is a config for relative position embeddings for BERT, so I'll just have it fall back to the original attention for configs using relative position embeddings.

@ArthurZucker - Please let me know if someone else is already working on SDPA for BERT, and I can look for something else to do. Thanks!
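
The fallback described above could look roughly like the sketch below. `BertLikeConfig` and `resolve_attention_implementation` are hypothetical names for illustration, not Transformers APIs; the `position_embedding_type` values mirror the ones BERT configs use ("absolute", "relative_key", "relative_key_query").

```python
from dataclasses import dataclass


@dataclass
class BertLikeConfig:
    """Hypothetical stand-in for a model config."""

    position_embedding_type: str = "absolute"
    attn_implementation: str = "sdpa"


def resolve_attention_implementation(config: BertLikeConfig) -> str:
    """Use SDPA only when the config is compatible with it, else fall back to eager."""
    if config.attn_implementation == "sdpa" and config.position_embedding_type != "absolute":
        return "eager"
    return config.attn_implementation


absolute_choice = resolve_attention_implementation(BertLikeConfig())
relative_choice = resolve_attention_implementation(
    BertLikeConfig(position_embedding_type="relative_key")
)
print(absolute_choice, relative_choice)
```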

hackyon avatar Jan 30 '24 22:01 hackyon

Not sure anyone is working on that, but BERT is already so small that I doubt it will have a lot of impact on perf!

ArthurZucker avatar Jan 31 '24 01:01 ArthurZucker

@ArthurZucker for the T5 family of models, attention bias is required, so flash-attention won't work for now but torch SDPA can still use the memory efficient kernel from xformers, right? I did some benchmarking with Chronos models (based on T5 architecture) here (https://github.com/amazon-science/chronos-forecasting/issues/33) and there's a clear speedup when using torch SDPA.

abdulfatir avatar Mar 31 '24 21:03 abdulfatir

@abdulfatir That's correct

fxmarty avatar Apr 02 '24 09:04 fxmarty

I can open a PR for T5 with SDPA then. Are there specific things that I should know of or a reference that I can look at?

abdulfatir avatar Apr 02 '24 09:04 abdulfatir

@abdulfatir For sure, some specific things that are good to know:

  • https://github.com/pytorch/pytorch/issues/108108 (is_causal=True may not do what you expect)
  • https://github.com/pytorch/pytorch/issues/110213 (you need https://github.com/huggingface/transformers/blob/416711c3ea88109cf25a9c5f85b4aeee2cb831b5/src/transformers/modeling_attn_mask_utils.py#L189)

example of a PR: https://github.com/huggingface/transformers/pull/29108
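
One way `is_causal=True` can surprise, assuming the concern is the case where the query length differs from the key length: the PyTorch documentation describes the implied mask as `torch.ones(L, S).tril(diagonal=0)`, which is aligned to the top-left corner. The pure-Python sketch below (helper names are hypothetical) contrasts that with a mask aligned to the end of the keys, which is what decoding with a KV cache usually needs:

```python
def top_left_causal(q_len: int, kv_len: int) -> list[list[bool]]:
    """Equivalent to ones(q_len, kv_len).tril(diagonal=0): query i sees keys 0..i."""
    return [[j <= i for j in range(kv_len)] for i in range(q_len)]


def bottom_right_causal(q_len: int, kv_len: int) -> list[list[bool]]:
    """Aligned to the last key: query i sees keys 0..i + (kv_len - q_len)."""
    offset = kv_len - q_len
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]


# Decoding one new token against 4 cached keys
decode_top_left = top_left_causal(1, 4)
decode_bottom_right = bottom_right_causal(1, 4)
print(decode_top_left)      # the new token would see only the first key
print(decode_bottom_right)  # the new token sees every cached key
```

When the two lengths are equal the masks coincide, so the difference only shows up once a cache is involved.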

fxmarty avatar Apr 02 '24 09:04 fxmarty

FYI, going forward we should rather use https://github.com/huggingface/transformers/blob/416711c3ea88109cf25a9c5f85b4aeee2cb831b5/src/transformers/models/llama/modeling_llama.py#L1058, as it is more self-contained and easier to debug and maintain than the many paths in the attn_mask utils.

ArthurZucker avatar Apr 02 '24 09:04 ArthurZucker

Hey @abdulfatir, just wanted to check in on whether you are still working on a PR to add SDPA support for T5. It would tremendously help accelerate diffusion models that use T5.

sayakpaul avatar Apr 19 '24 10:04 sayakpaul

@sayakpaul sorry, I was on vacation. Will look into this now and maybe open a PR in a couple of days. I didn't know that there were diffusion models using the T5 arch. Pretty cool!

abdulfatir avatar Apr 19 '24 11:04 abdulfatir

Amongst the open ones that are available, the most notable ones are:

  • PixArt-Alpha
  • PixArt-Sigma
  • SD3 (will soon be open)
  • DeepFloyd IF

So, it will be huge :)

sayakpaul avatar Apr 19 '24 12:04 sayakpaul

@fxmarty @ArthurZucker @sayakpaul I have opened a PR #30375 for T5. I still have a couple of questions due to some tests failing. Let's discuss those on the PR.

abdulfatir avatar Apr 21 '24 17:04 abdulfatir

Out of interest, is there any merit in implementing SDPA/FA2 for the DeBERTa family of models? Parametrically, the models are quite small but have relatively high memory costs due to (among other things) the disentangled attention mechanism - however I'm uncertain whether SDPA/FA2 would have a material effect on these costs.

There was a PR to refactor the DeBERTa modelling scripts (#22105), which appears to have been set aside for now, and I am unsure whether any potential implementations for DeBERTa were gated by that or waiting for it to avoid clashes. There are two tangentially related PRs (#27734, relating to DeBERTa and #28802, relating to BERT, but potentially usable as a starting point) but both appear to have been stuck awaiting some final review or merging for a month or two now.

I think there's still significant utility in the DeBERTa and other BERT family models given their lightweight size and still-leading functionality for text classification, although of course the limelight is held by larger text-completion/multi-modal models, especially with the rapid release of Phi-3, Llama-3 etc. over the past few weeks.

While most of DeBERTa's use cases employ short, batched input strings, which diminishes the effect of memory-efficient optimizations, there are tasks, particularly entailment detection between complex sentences, that can create large input strings and very high memory costs.
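
To put rough numbers on that, here is a back-of-the-envelope sketch of the memory needed for the attention score matrix that an eager implementation materializes in one layer. The batch size, head count, and fp16 element size are illustrative assumptions, not DeBERTa measurements, and disentangled attention adds further terms that are not counted here.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes for one layer's (batch, heads, seq_len, seq_len) score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_element


short_inputs = attention_scores_bytes(batch=32, heads=12, seq_len=128)
long_inputs = attention_scores_bytes(batch=32, heads=12, seq_len=2048)
growth = long_inputs // short_inputs

print(f"{short_inputs / 2**20:.0f} MiB vs {long_inputs / 2**30:.0f} GiB ({growth}x)")
```

Because the cost grows with the square of the sequence length, a 16x longer input needs 256x more memory for the scores, which is the regime where kernels that avoid materializing the matrix matter most.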

His-Wardship avatar Apr 26 '24 10:04 His-Wardship

@His-Wardship - Thanks for your input. We have finally managed to merge #28802 (SDPA for BERT), so hopefully this will unlock other SDPA implementations for other models. I plan to open up a PR for RoBERTa and do some testing to see if there are any performance gains there.

Regarding DeBERTa, I'm somewhat skeptical that SDPA supports its disentangled attention mechanism. I think the disentangled attention mechanism requires some extra operations that may not be supported in SDPA.

hackyon avatar Apr 26 '24 19:04 hackyon