cudnn-frontend
What's the difference between the flash attention implementations in cuDNN and Dao-AILab?
Does this link point to a flash attention implementation?
There is no difference in the algorithm or numerics between cuDNN and Dao-AILab. The cuDNN implementation benefits from NVIDIA's in-house kernel-development expertise and aims to maximize the hardware's capabilities.
Please use the following samples as code snippets to use cuDNN flash attention:
- C++: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/samples/cpp/mha.cpp
- Python: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/samples/python/test_mhas.py
- Documentation: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/docs/operations/Attention.md
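For reference, the computation that both implementations target (and on which their numerics agree) is plain scaled dot-product attention; flash attention changes the memory-access pattern and tiling, not the math. A minimal single-head NumPy sketch of that reference computation:

```python
# Reference scaled dot-product attention in NumPy -- the computation that
# both the cuDNN and the Dao-AILab flash attention kernels implement.
import numpy as np

def sdpa_reference(q, k, v, causal=False):
    """q, k, v: (seq_len, head_dim) arrays for a single head."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)              # (seq_len, seq_len) logits
    if causal:
        seq_len = q.shape[0]
        mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v                          # (seq_len, head_dim)

q = np.random.randn(8, 16)
out = sdpa_reference(q, q, q, causal=True)
print(out.shape)  # (8, 16)
```

The fused kernels produce the same result as this reference (up to floating-point precision), just far faster and without materializing the full seq_len x seq_len score matrix.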
@gautam20197 As far as I know, flash attention has been implemented by NVIDIA in TensorFlow (see cuda_dnn.cc), right?
@MoFHeka , it is not correct to say it is implemented in tensorflow, it is implemented in XLA and there is a PR https://github.com/openxla/xla/pull/6872 pending to integrate the final piece of flash attention in XLA. Once this PR is merged, you can access flash attention from JAX/Tensorflow if the pattern is supported.
@Cjkkkk So if I understand correctly, in addition to TF/JAX, PyTorch can also use OpenXLA to work with cuDNN.
@MoFHeka PyTorch eager mode has a path to cuDNN's optimized attention.
I think we've addressed the original question. Going to close for now
Is there any benchmark comparing cuDNN fused attention and flash attention? I recently found that TorchACC already supports cuDNN fused attention in PyTorch training, so presumably a benchmark exists, even if only end-to-end C++ performance numbers. @mnicely @Cjkkkk @gautam20197
I'm eager to know how to verify whether the acceleration after I turn on XLA has reached its ideal state.
I think you can check your use case using the PyTorch nightlies:

`pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121`

and running the PyTorch SDPA example https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html with `TORCH_CUDNN_SDPA_ENABLED=1` set.
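A minimal sketch of that check, assuming a recent PyTorch build; the `TORCH_CUDNN_SDPA_ENABLED=1` variable mentioned above gates the cuDNN backend in those nightlies. On a machine without a GPU, `scaled_dot_product_attention` falls back to the math backend, so the script stays runnable either way:

```python
# Minimal SDPA smoke test (sketch). Set TORCH_CUDNN_SDPA_ENABLED=1 before
# importing torch so the nightlies above can pick the cuDNN backend.
import os
os.environ.setdefault("TORCH_CUDNN_SDPA_ENABLED", "1")

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, heads, seq_len, head_dim); head_dim=128 is the sweet spot
# discussed later in this thread.
q = torch.randn(1, 4, 256, 128, device=device, dtype=dtype)
k = torch.randn(1, 4, 256, 128, device=device, dtype=dtype)
v = torch.randn(1, 4, 256, 128, device=device, dtype=dtype)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 4, 256, 128])
```

Which backend actually ran can be confirmed with the backend-selection context managers described in the linked SDPA tutorial.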
@mnicely Thank you very much for your answer. May I ask how much improvement there is compared to Dao-AILab flash attention 2, according to your evaluation?
We recently released cuDNN v9. FP16 and BF16 fused flash attention engine performance has been significantly improved on NVIDIA GPUs:
- Speed-up of up to 50% over cuDNN 8.9.7 on Hopper GPUs.
- Speed-up of up to 100% over cuDNN 8.9.7 on Ampere GPUs.
We say "up to" because the gain depends on the problem parameters.
@mnicely I noticed that speed-up benchmark in the cuDNN release notes recently. Yes, it looks great, but are there more details on the QKV shapes and other parameters? A single speed-up number for one particular configuration is not convincing enough; we need a repeatable experiment.
@MoFHeka Problem sizes with hidden dimension per head (d) = 128 gain the most significant speedup on both Hopper and Ampere.
@gautam20197 head (d) = 128 with any batch size or sequence length?
Yes, there will be a healthy speedup for all batch sizes and sequence lengths.
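One way to get the repeatable experiment asked for above is a small micro-benchmark at head_dim=128 swept over sequence lengths. This is a hypothetical harness, not an official NVIDIA benchmark; on a CUDA machine you would move the tensors to `"cuda"`, add `torch.cuda.synchronize()` around the timer, and run once with and once without `TORCH_CUDNN_SDPA_ENABLED=1` to compare backends:

```python
# Sketch of a repeatable SDPA micro-benchmark (hypothetical harness).
import time
import torch
import torch.nn.functional as F

def bench_sdpa(batch, heads, seq_len, head_dim=128, iters=10, device="cpu"):
    """Return average seconds per SDPA call for one problem size."""
    q = torch.randn(batch, heads, seq_len, head_dim, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    F.scaled_dot_product_attention(q, k, v, is_causal=True)  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return (time.perf_counter() - start) / iters

for seq_len in (512, 1024, 2048):
    ms = bench_sdpa(batch=1, heads=8, seq_len=seq_len) * 1e3
    print(f"seq_len={seq_len}: {ms:.2f} ms/iter")
```

Sweeping batch size and sequence length this way gives the per-shape numbers the release notes summarize as "up to" figures.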
@mnicely Unable to run:

```python
>>> import torch
>>> t1 = torch.randn(1, 4, 4096, 128).to("cuda").to(torch.float16)
>>> torch._scaled_dot_product_cudnn_attention(t1, t1, t1, dropout_p=0.0, is_causal=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1711266070736/work/aten/src/ATen/native/cudnn/MHA.cpp":410, please report a bug to PyTorch.
```