
torch 2.5 CuDNN backend for SDPA NaN error


Describe the bug

When using the recently released PyTorch 2.5, the default SDPA backend is CUDNN_ATTENTION. With this backend, the CogVideoX LoRA training script in examples produces NaN gradients right at the first step. Other SDPA backends, such as FLASH_ATTENTION or EFFICIENT_ATTENTION, do not lead to NaN issues.

After some preliminary investigation, I found that this might be related to the transpose and reshape operations that follow the SDPA call (see L1954):

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
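
To isolate this, the pattern can be exercised outside the training script. Below is a minimal sketch (shapes, dtype, and device are illustrative and not taken from the training script; it assumes a CUDA device and PyTorch >= 2.5) that runs SDPA under each backend, applies the same transpose and reshape, backpropagates, and reports whether the query gradient contains NaNs:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes, not the ones used by CogVideoX
batch_size, heads, seq_len, head_dim = 2, 8, 128, 64

for backend in (SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    query = torch.randn(batch_size, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    key = torch.randn_like(query, requires_grad=True)
    value = torch.randn_like(query, requires_grad=True)

    with sdpa_kernel(backend):
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

    # Same post-processing as in the attention processor
    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    hidden_states.float().sum().backward()

    print(backend, "NaN in query.grad:", torch.isnan(query.grad).any().item())

This toy case may not reproduce the NaN by itself (the training script also involves LoRA layers and mixed precision), but it isolates the backend plus transpose/reshape combination for comparison.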

Some related issues and PRs:

- https://github.com/pytorch/pytorch/issues/134001
- https://github.com/pytorch/pytorch/pull/134031
- https://github.com/pytorch/pytorch/pull/138354

Furthermore, other attention processors in attention_processor.py use the same transpose and reshape pattern, for example FluxAttnProcessor2_0, and could run into the same problem.

Reproduction

This issue can be reproduced by setting a breakpoint after the backward pass and printing the gradients:

# Per-sample weighted MSE loss, as in the training script
loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
loss = loss.mean()
accelerator.backward(loss)
# With the CUDNN_ATTENTION backend, the printed gradients are NaN at the first step
print([[name, param.grad] for name, param in transformer.named_parameters() if param.requires_grad])
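
For a more compact check (a suggestion, not part of the original script), one can list only the parameters whose gradients contain NaNs:

nan_params = [
    name
    for name, param in transformer.named_parameters()
    if param.requires_grad and param.grad is not None and torch.isnan(param.grad).any()
]
print("parameters with NaN gradients:", nan_params)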

Changing the SDPA backend to FLASH_ATTENTION or EFFICIENT_ATTENTION in attention_processor.py avoids the NaN issue:

from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # or SDPBackend.EFFICIENT_ATTENTION
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
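
As an alternative workaround that does not require editing attention_processor.py, the cuDNN SDPA backend can, if I understand the PyTorch 2.5 API correctly, be disabled process-wide at the top of the training script so that SDPA falls back to the flash / memory-efficient / math backends:

import torch

# Disable only the cuDNN SDPA backend; the other backends remain available
torch.backends.cuda.enable_cudnn_sdp(False)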

Considering that PyTorch 2.5 is currently the default version available for installation, this issue may require some attention.

Logs

No response

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.10.134-010
  • Running on Google Colab?: No
  • Python version: 3.10.15
  • PyTorch version (GPU?): 2.5.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.1
  • Transformers version: 4.46.0
  • Accelerate version: 1.0.1
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: 8× NVIDIA H20, 97871 MiB each
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @a-r-r-o-w @yiyixuxu @sayakpaul

wtyuan96 · Oct 25 '24