
What's the difference between the flash attention implementations in cuDNN and Dao-AILab?

MoFHeka opened this issue 6 months ago · 15 comments

Is the attention at this link an implementation of flash attention?

MoFHeka avatar Dec 20 '23 07:12 MoFHeka

There is no difference in the algorithm or numerics between cuDNN and Dao-AILab. The cuDNN implementation benefits from NVIDIA's in-house kernel-development expertise and aims to maximize the hardware's capabilities.

Please use the following samples as code snippets for cuDNN flash attention:

  • C++: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/samples/cpp/mha.cpp
  • Python: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/samples/python/test_mhas.py
  • Documentation: https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/docs/operations/Attention.md

gautam20197 avatar Dec 20 '23 19:12 gautam20197

@gautam20197 As far as I know, flash attention has already been implemented by NVIDIA in TensorFlow, right? See cuda_dnn.cc.

MoFHeka avatar Dec 21 '23 04:12 MoFHeka

@MoFHeka, it is not correct to say it is implemented in TensorFlow; it is implemented in XLA, and there is a pending PR (https://github.com/openxla/xla/pull/6872) to integrate the final piece of flash attention into XLA. Once that PR is merged, you can access flash attention from JAX/TensorFlow if the pattern is supported.

Cjkkkk avatar Jan 02 '24 20:01 Cjkkkk

@Cjkkkk So if I understand correctly, in addition to TF/JAX, PyTorch can also use OpenXLA to work with cuDNN?

MoFHeka avatar Jan 05 '24 06:01 MoFHeka

@MoFHeka PyTorch eager mode has a path to cuDNN's optimized attention.

mnicely avatar Feb 21 '24 16:02 mnicely

I think we've addressed the original question. Going to close for now.

mnicely avatar Feb 21 '24 16:02 mnicely

Is there any benchmark comparing cuDNN fused attention with Dao-AILab flash attention? I recently found that TorchACC already supports using cuDNN fused attention in PyTorch training, so there must be a benchmark, right? Even end-to-end C++ performance numbers would help. @mnicely @Cjkkkk @gautam20197

I am eager to know how to verify whether the acceleration I see after turning on XLA has reached the ideal state.

MoFHeka avatar Mar 22 '24 13:03 MoFHeka

I think you can check your use case using the PyTorch nightlies:

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

Then run the PyTorch SDPA example (https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html) with TORCH_CUDNN_SDPA_ENABLED=1 set.
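For a quick, repeatable check, a minimal script along these lines should work; this is only a sketch, it assumes a CUDA-capable GPU with a recent nightly, and the shapes are illustrative rather than an official benchmark:

```python
# Sketch: time the public SDPA entry point with the cuDNN path enabled.
# Run as: TORCH_CUDNN_SDPA_ENABLED=1 python sdpa_check.py
import torch
import torch.nn.functional as F
from torch.utils import benchmark

batch, heads, seq_len, head_dim = 8, 16, 2048, 128  # illustrative shapes
q, k, v = (torch.randn(batch, heads, seq_len, head_dim,
                       device="cuda", dtype=torch.float16) for _ in range(3))

timer = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
print(timer.blocked_autorange(min_run_time=1.0))
```

Running the same script with and without the environment variable (or against a Dao-AILab FlashAttention install) gives a like-for-like number for your own shapes.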

mnicely avatar Mar 22 '24 15:03 mnicely

@mnicely Thank you very much for your answer. May I ask how much improvement there is compared to Dao-AILab FlashAttention-2, according to your evaluation?

MoFHeka avatar Mar 22 '24 16:03 MoFHeka

We recently released cuDNN v9. FP16 and BF16 fused flash attention engine performance has been significantly improved for NVIDIA GPUs:

  • Speed-up of up to 50% over cuDNN 8.9.7 on Hopper GPUs.
  • Speed-up of up to 100% over cuDNN 8.9.7 on Ampere GPUs.

We say "up to" because it depends on the problem parameters.

mnicely avatar Mar 22 '24 17:03 mnicely

@mnicely I recently noticed that speed-up benchmark in the cuDNN release notes. Yes, it looks great. But are there more details on the QKV shapes and other settings? A single speedup in one particular configuration is not convincing enough; we need a repeatable experiment.

MoFHeka avatar Mar 24 '24 03:03 MoFHeka

@MoFHeka Problem sizes with a hidden dimension per head (d) of 128 gain the most significant speedup on both Hopper and Ampere.

gautam20197 avatar Mar 25 '24 16:03 gautam20197

@gautam20197 Head dimension (d) = 128 with any batch size or sequence length?

MoFHeka avatar Mar 25 '24 17:03 MoFHeka

Yes, there will be a healthy speedup for all batch sizes and sequence lengths.
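If you want a repeatable experiment rather than a single data point, a small sweep over sequence length at d = 128 is easy to script; the sketch below is illustrative only (batch size, head count, dtype, and timing settings are assumptions, not an official benchmark):

```python
# Sketch: sweep sequence lengths at head_dim = 128 and report SDPA latency.
# Run as: TORCH_CUDNN_SDPA_ENABLED=1 python sdpa_sweep.py
import torch
import torch.nn.functional as F
from torch.utils import benchmark

batch, heads, head_dim = 4, 16, 128  # head_dim = 128 is the sweet spot noted above
for seq_len in (512, 1024, 2048, 4096):
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim,
                           device="cuda", dtype=torch.bfloat16) for _ in range(3))
    t = benchmark.Timer(
        stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
        globals={"F": F, "q": q, "k": k, "v": v},
    )
    m = t.blocked_autorange(min_run_time=0.5)
    print(f"seq_len={seq_len:5d}  median={m.median * 1e3:.3f} ms")
```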

gautam20197 avatar Mar 25 '24 19:03 gautam20197

> I think you can check your use case using the PyTorch nightlies: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
>
> Then run the PyTorch SDPA example (https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html) with TORCH_CUDNN_SDPA_ENABLED=1 set.

@mnicely I am unable to run it:

>>> import torch
>>> t1=torch.randn(1,4,4096,128).to("cuda").to(torch.float16)
>>> torch._scaled_dot_product_cudnn_attention(t1, t1, t1, dropout_p=0.0, is_causal=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1711266070736/work/aten/src/ATen/native/cudnn/MHA.cpp":410, please report a bug to PyTorch.

MoFHeka avatar Mar 26 '24 07:03 MoFHeka