
[Feature] Enable CUDNN Attention

Open TJ-Solergibert opened this issue 8 months ago • 5 comments

Recently PyTorch integrated the cuDNN attention backend into torch.nn.functional.scaled_dot_product_attention. I've tried torch.nn.functional.scaled_dot_product_attention on 2 different machines with H100s, and on both machines it dispatches to SDPBackend.FLASH_ATTENTION. Manually switching to SDPBackend.CUDNN_ATTENTION provides a ~10% end-to-end improvement.

# torchtitan/models/llama/model.py
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
...
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
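
As a quick sanity check, something like the following micro-benchmark sketch can be used to compare the two backends in isolation on an H100 (the shapes below are illustrative, not the exact configuration from the runs described next):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes: batch 1, 32 heads, seq 8192, head dim 128, bf16 on CUDA.
q, k, v = (torch.randn(1, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

def bench(backend, iters=10):
    with sdpa_kernel(backend):
        for _ in range(3):  # warmup
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()
        start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        start.record()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # ms per call

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION):
    print(backend, f"{bench(backend):.3f} ms")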

You can check some wandb runs here. I've tried a Llama3 8B model with 14 layers, batch size 1, sequence length 8192, and no activation checkpointing on 2 GPUs. I used this config in order to run the model with fewer GPUs and without recompilation. I tried to run the experiments with a single GPU, but with different PyTorch versions I was getting an error from scaled_dot_product_attention (No available kernel. Aborting execution.). I'm using 2.7.0.dev20250310+cu126.

cc @drisspg

PS: I could open a PR to select the SDPBackend as an experimental feature.
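
A minimal sketch of what such an experimental toggle could look like (the config string and helper below are hypothetical, not existing torchtitan code):

import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Hypothetical mapping from a config string to an SDPA backend.
_SDPA_BACKENDS = {
    "flash": SDPBackend.FLASH_ATTENTION,
    "cudnn": SDPBackend.CUDNN_ATTENTION,
    "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
    "math": SDPBackend.MATH,
}

def sdpa(xq, xk, xv, backend_name=None):
    # No backend requested: keep PyTorch's default dispatch behavior.
    if backend_name is None:
        return F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
    with sdpa_kernel(_SDPA_BACKENDS[backend_name]):
        return F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)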

TJ-Solergibert avatar Apr 02 '25 20:04 TJ-Solergibert

So there were some issues with the cuDNN backend that have mostly been solved. For the 2.8 release we are planning on enabling it by default for H100 and B200.

@TJ-Solergibert

I tried to run the experiments with a single GPU, but with different PyTorch versions I was getting an error from scaled_dot_product_attention (No available kernel. Aborting execution.). I'm using 2.7.0.dev20250310+cu126.

Is this saying that you tried a run forcing cuDNN and it couldn't be run? Do you have a config or some way of reproing it?

This error should come with some warning logs saying why it couldn't be run.

drisspg avatar Apr 03 '25 00:04 drisspg

Hi @drisspg,

thanks for your quick response! The error trace is as follows:

[rank0]:[titan] 2025-04-03 13:32:26,173 - root - INFO - Training starts at step 1.
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: Memory efficient kernel not used because: (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:838.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:548.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: Flash attention kernel not used because: (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:840.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: Flash attention has been runtime disabled. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:536.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: CuDNN attention kernel not used because: (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:842.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: Experimental cuDNN SDPA nested tensor support does not support backward. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:546.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:/iopsstor/scratch/cscs/asolergi/mm/torchtitan/torchtitan/models/llama/model.py:260: UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: float, Key dtype: float, and Value dtype: float instead. (Triggered internally at /pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:90.)
[rank0]:  output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

In short, when running with a single GPU the Q, K & V matrices are in torch.float32, which triggers this error when specifying either the SDPBackend.FLASH_ATTENTION or the SDPBackend.CUDNN_ATTENTION backend. There is no error when NOT specifying any backend. When using FSDPv2, the Q, K & V matrices are in torch.bfloat16, so there is no issue!
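
A minimal repro sketch of this dtype constraint (assuming a CUDA GPU where the cuDNN SDPA backend is available):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def try_dtype(dtype):
    q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=dtype)
               for _ in range(3))
    try:
        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        print(dtype, "-> dispatched to cuDNN")
    except RuntimeError as err:
        # Expected for float32: "No available kernel. Aborting execution."
        print(dtype, "->", err)

try_dtype(torch.float32)   # rejected: cuDNN SDPA needs half/bfloat16 inputs
try_dtype(torch.bfloat16)  # works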

TJ-Solergibert avatar Apr 03 '25 12:04 TJ-Solergibert

@TJ-Solergibert get_train_context() will specify the SDPBackend, including memory efficient, cudnn, and flash, when CP is enabled. We can extend it to the case when CP is not enabled. As for the error you reported, it seems reasonable: with torch.float32 you can only use memory efficient, not flash or cudnn. What are the target parallelisms and configurations you want to use but are failing to run?
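
Roughly, the extension could look something like this (a sketch only, not the actual get_train_context implementation; the enable_cp/sdpa_backends arguments are illustrative):

import contextlib
from torch.nn.attention import sdpa_kernel, SDPBackend

def get_train_context(enable_cp, sdpa_backends=None):
    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            # Apply the explicit backend list even when CP is disabled,
            # e.g. [SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION].
            if sdpa_backends:
                stack.enter_context(sdpa_kernel(sdpa_backends))
            # CP-specific context managers would be stacked here when
            # enable_cp is True (omitted in this sketch).
            yield
    return context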

fegin avatar Apr 03 '25 17:04 fegin

When using a single GPU with the 8B model (14 layers instead of 32), without torch.compile, and with activation_checkpoint.mode = none. Basically, we don't apply any parallelism in parallelize_llama.

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1

TJ-Solergibert avatar Apr 03 '25 18:04 TJ-Solergibert

In such a case, you probably need autocast in get_train_context(). We apply all the mixed precision within parallelize_llama. With only one GPU, nothing is going to be added on top of the original model. I think train_context is a reasonable place to add this feature for the single-GPU case.
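
For example, something along these lines (a sketch under the assumption that bfloat16 autocast is acceptable for the single-GPU path; the name is illustrative):

import contextlib
import torch

@contextlib.contextmanager
def single_gpu_train_context(mixed_precision_dtype=torch.bfloat16):
    # Autocast so Q/K/V reach scaled_dot_product_attention in bfloat16,
    # matching what FSDP mixed precision provides in the multi-GPU case.
    with torch.autocast(device_type="cuda", dtype=mixed_precision_dtype):
        yield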

fegin avatar Apr 03 '25 18:04 fegin