Function 'FlashAttnQKVPackedFuncBackward' returned nan values in its 0th output

Open Karbo123 opened this issue 3 years ago • 7 comments
I found that my training loss became NaN after many epochs, but the model params were all finite (checked with torch.isfinite) after the NaN loss occurred; I verified this by loading the saved checkpoint file from disk. I then resumed training from that checkpoint (i.e. the one saved after the NaN happened). At first the training process seemed to be okay, but after several epochs the NaN loss occurred again.

I set torch.autograd.set_detect_anomaly(True) and CUDA_LAUNCH_BLOCKING=1 to find out what happened, and the result showed that FlashAttnQKVPackedFuncBackward returned NaN values in its output.

This is very strange: the model params are all finite, and the NaN appears only after several epochs, so the training data should be okay. But if the backward pass produces NaN values, why don't the model params contain any NaN values? BTW, I didn't use any gradient clipping.

I also checked all the model params, and they seem to be fine. The maximum absolute value among all the params (except the non-learnable frequency weight) is 1.415, which is not too large.

Do you have any suggestions on this? Is the FlashAttnQKVPackedFunc numerically unstable? Thank you very much! Looking forward to your reply.

Karbo123 avatar Sep 01 '22 12:09 Karbo123

Thanks for the report. The function should be numerically stable.
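As an illustration of what "numerically stable" usually means for a softmax, here is a minimal sketch of the standard max-subtraction trick in plain Python. It is not FlashAttention's kernel code; the function names logsumexp and stable_softmax are hypothetical and used only for this example.

```python
import math

def logsumexp(scores):
    """Log of the sum of exponentials, computed with the max-subtraction trick."""
    m = max(scores)
    # Every exponent below is <= 0, so math.exp cannot overflow.
    return m + math.log(sum(math.exp(s - m) for s in scores))

def stable_softmax(scores):
    """Softmax expressed through the log-sum-exp of the scores."""
    lse = logsumexp(scores)
    return [math.exp(s - lse) for s in scores]

# math.exp(1000.0) on its own raises OverflowError; the shifted form stays in range.
scores = [1000.0, 999.0, 998.0]
probs = stable_softmax(scores)
print(probs)
```

The trick keeps the softmax itself from overflowing, but it cannot help if the scores passed in are already inf or NaN.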

Which commit of FlashAttention are you using? On which GPU? What are the dimensions of the attention?

In order for me to reproduce the issue, can you save the arguments that caused the NaN and send them to me? That'd be very helpful. For example, you can add these lines right before the return statement of the backward function to save the tensors to the file nan_repro.pt:

        # Dump every input of the backward pass as soon as a NaN shows up in dqkv.
        if dqkv.isnan().any():
            state_dict = {'dout': dout, 'qkv': qkv, 'out': out, 'softmax_lse': softmax_lse,
                          'cu_seqlens': cu_seqlens, 'max_seqlen': ctx.max_seqlen,
                          'dropout_p': ctx.dropout_p, 'softmax_scale': ctx.softmax_scale,
                          'causal': ctx.causal, 'rng_state': rng_state}
            # torch.save takes the object first, then the file path.
            torch.save(state_dict, 'nan_repro.pt')
            breakpoint()

Thanks for your help!

tridao avatar Sep 01 '22 16:09 tridao

I was having the same problem... Long story short, it looks like dq, dk and dv need to be zeroed-out, since they are used as accumulators? However, currently, flash_attention_interface allocates them via torch.empty_like. Setting them to 0 before flash_attn_cuda.bwd seems to have resolved the issue.

vadimcn avatar Nov 04 '22 04:11 vadimcn

I'm very curious about this. I think all of the values in dq, dk, dv should be overwritten during the execution of the backward pass.

The only problematic scenario I could imagine is when q, k, v are longer than what cu_seqlens indicates. For example, if q has shape (10, nheads, headdim), where 10 is supposed to be the total batch * seqlen, then dq is allocated as torch.empty_like(q). If e.g. cu_seqlens = [0, 5, 8] (which says that the batch has 2 sequences, the 1st stored at indices 0 -> 4 and the 2nd stored at indices 5 -> 7), then during the execution only the values at indices 0 -> 7 are written, and the values at indices 8 -> 9 are not overwritten.
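The scenario above can be checked on the caller's side before calling the attention function. The following is a minimal sketch in plain Python; the helper name unwritten_rows is hypothetical and not part of flash-attn, and it assumes cu_seqlens is an ordinary list of cumulative sequence boundaries.

```python
def unwritten_rows(total_rows, cu_seqlens):
    """Return the packed-row indices that no sequence in cu_seqlens covers.

    Sequence i occupies rows cu_seqlens[i] .. cu_seqlens[i + 1] - 1, so any
    row at or beyond cu_seqlens[-1] is never written by the backward pass.
    """
    covered_end = cu_seqlens[-1]  # one past the last row any sequence uses
    if covered_end > total_rows:
        raise ValueError("cu_seqlens refers to rows beyond the packed tensor")
    return list(range(covered_end, total_rows))

# The example from the comment: 10 packed rows, two sequences of length 5 and 3.
print(unwritten_rows(10, [0, 5, 8]))  # [8, 9]
```

If the list is non-empty, the gradient rows at those indices are whatever torch.empty_like left in memory, so either trim q, k, v to cu_seqlens[-1] rows or zero those gradient rows explicitly.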

@vadimcn could you say more about your setting? Does it fall into this case?

tridao avatar Nov 04 '22 05:11 tridao

The only problematic scenario I could imagine is when q, k, v are longer than what cu_seqlens indicate.

Yes, this was indeed the case :facepalm:. Thanks for the hint!

vadimcn avatar Nov 05 '22 18:11 vadimcn

@tridao I can reproduce this error after a very short time using the script below. Is this a user-side usage error?

import torch
import torch.optim as optim
import torch.nn.functional as F
from functools import partial
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

embed_dim = 256 # !!! change this to 64 and the error will not be observable !!!
batch_size = 16
num_heads = 8
seq_length = 512
dim_feedforward = 1024
learning_rate = 0.01
device = torch.device("cuda")
torch.set_default_dtype(torch.bfloat16)
torch.autograd.set_detect_anomaly(True)

# Initialize model
model = Block( # TransformerEncoderLayer
        embed_dim,
        mixer_cls=partial(
            MHA,
            num_heads=num_heads,
            use_flash_attn=True,
            rotary_emb_dim=0,
        ),
        mlp_cls=partial(Mlp, hidden_features=dim_feedforward),
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        prenorm=False,
    ).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
inputs = torch.full((batch_size, seq_length, embed_dim), 1000.0, device=device)

# Training loop
for i in range(999999):
    print(f'Iteration {i + 1}')
    optimizer.zero_grad()
    output = model(inputs)
    loss = output.mean()
    loss.backward()
    optimizer.step()

Output:

/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py:959: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py:1018: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: Error detected in FlashAttnQKVPackedFuncBackward. Traceback of forward call that caused the error:
  File "/home/otto/Development/temp/test.py", line 41, in <module>
    output = model(inputs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/block.py", line 195, in forward
    mixer_out = self.mixer(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 670, in forward
    context = self.inner_attn(qkv, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 122, in forward
    return flash_attn_qkvpacked_func(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 729, in flash_attn_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:110.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/otto/Development/temp/test.py", line 43, in <module>
    loss.backward()
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/otto/Development/temp/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'FlashAttnQKVPackedFuncBackward' returned nan values in its 0th output.

Latest published version of flash-attention.

otto-dev avatar Sep 13 '24 02:09 otto-dev

Does the same thing happen if you use standard implementation of attention? i.e. try use_flash_attn=False

tridao avatar Sep 13 '24 03:09 tridao

use_flash_attn=False

Then it works fine @tridao

otto-dev avatar Sep 13 '24 04:09 otto-dev

@otto-dev @tridao any updates? How to fix it?

Oktai15 avatar Oct 21 '24 17:10 Oktai15

Is there any progress?

I ran into a similar issue in my own code, so I searched and found this issue.

I tried to debug using this code snippet and my own training code. Here are my environment and some test results:

  • Hardware environment: A100, Debian 5 on Linux 5.4, Driver Version: 470.129.06, CUDA Version: 12.4
  • PyTorch and flash-attn versions: PyTorch 2.5.1 + flash-attn 2.7.2, or PyTorch 2.3.2 + flash-attn 2.5.8
  • What I have tried:
  • With flash-attn and deterministic=False, the code snippet by @otto-dev raises the NaN error in FlashAttnQKVPackedFunc after about 100 iterations (the exact iteration number varies between runs). In my own model training code, the NaN error is also raised in FlashAttnQKVPackedFunc, but in the first iteration. Besides, if I rerun the same training code for my model 2 to 5 times, the NaN error disappears, and the training loop has run fine for about 10k training iterations so far.
  • With flash-attn and deterministic=True, the code snippet by @otto-dev raises the NaN error in NativeLayerNormBackward0 rather than in flash-attn, again after about 100 iterations (the exact iteration number varies between runs). In my own model training code, it has run fine for about 5k training iterations so far.

From the observations above, I guess that the nondeterministic CUDA kernel in the flash-attn backward pass causes the NaN in backward. I hope this helps @tridao. I will continue working on this issue, so could I open a PR if I find the bug?

ValMystletainn avatar Dec 31 '24 12:12 ValMystletainn

The backward NaN in the script from @otto-dev is most likely caused by a learning rate that is too large, rather than by the FA2 operator (at least in my several environments). With the 1e-2 setting, the output tensor before the final norm becomes larger, and setting a breakpoint with if dq.isnan().any() ... shows that the max norm of qkv is near the bound of fp16/bf16.

Besides, after simply lowering the learning rate to 1e-3, the script runs safely indefinitely. So could @otto-dev check it again with the smaller learning rate?
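To make "near the bound of fp16/bf16" concrete, here is a minimal sketch in plain Python that derives the largest finite value of each format from its bit layout and reports how much headroom a logged magnitude still has. The helper names max_finite and headroom are hypothetical, and the magnitude you pass in would have to come from your own run (for example the absolute maximum of qkv at the breakpoint).

```python
def max_finite(exponent_bits, mantissa_bits):
    """Largest finite value of an IEEE-style binary floating-point format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_exponent = bias  # the all-ones exponent field is reserved for inf/NaN
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exponent

FP16_MAX = max_finite(exponent_bits=5, mantissa_bits=10)  # 65504.0
BF16_MAX = max_finite(exponent_bits=8, mantissa_bits=7)   # about 3.39e38

def headroom(abs_max, dtype_max):
    """How many times abs_max could still grow before exceeding dtype_max."""
    return dtype_max / abs_max

print(FP16_MAX, BF16_MAX)
print(headroom(32752.0, FP16_MAX))  # 2.0: one more doubling reaches the fp16 limit
```

A headroom close to 1 means a single further increase in activation scale can overflow to inf, which then tends to turn into NaN in later arithmetic such as inf - inf.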

ValMystletainn avatar Jan 04 '25 08:01 ValMystletainn

I'm seeing a similar error in my training run with quite a small learning rate too. (5e-5)

sjoshi804 avatar Aug 07 '25 21:08 sjoshi804