flash-attention v2 with activation checkpointing (no_reentrant) raises RuntimeError
Using flash_attn_varlen_qkvpacked_func together with CheckpointImpl.NO_REENTRANT raises the RuntimeError below:
Traceback (most recent call last):
> File "/opt/tiger/antelope/train.py", line 718, in <module>
main()
└ <function main at 0x7f385c2679d0>
File "/opt/tiger/antelope/train.py", line 703, in main
train(
└ <function train at 0x7f385c267790>
File "/opt/tiger/antelope/train.py", line 503, in train
grad_scaler.scale(loss).backward()
│ │ └ tensor(3.9735, device='cuda:4', grad_fn=<DivBackward0>)
│ └ <function ShardedGradScaler.scale at 0x7f386a1dcca0>
└ <torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler object at 0x7f38356e2e80>
File "/usr/local/lib/python3.9/dist-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
│ │ └ <function backward at 0x7f3877d575e0>
│ └ <module 'torch.autograd' from '/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py'>
└ <module 'torch' from '/usr/local/lib/python3.9/dist-packages/torch/__init__.py'>
File "/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
│ │ └ <method 'run_backward' of 'torch._C._EngineBase' objects>
│ └ <torch._C._EngineBase object at 0x7f3878059350>
└ <class 'torch.autograd.variable.Variable'>
File "/usr/local/lib/python3.9/dist-packages/torch/autograd/function.py", line 274, in apply
return user_fn(self, *args)
│ │ └ (tensor([[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
│ │ 0.0000e+00, 0.0000e+00],
│ │ [ 0.0000e+...
│ └ <torch.autograd.function.FlashAttnVarlenQKVPackedFuncBackward object at 0x7f2bd62bf220>
└ <function FlashAttnVarlenQKVPackedFunc.backward at 0x7f38500424c0>
File "/usr/local/lib/python3.9/dist-packages/flash_attn/flash_attn_interface.py", line 146, in backward
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
│ └ <attribute 'saved_tensors' of 'torch._C._FunctionBase' objects>
└ <torch.autograd.function.FlashAttnVarlenQKVPackedFuncBackward object at 0x7f2bd62bf220>
RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/saved_variable.cpp":227, please report a bug to PyTorch. No grad accumulator for a saved leaf
I'm not familiar with autograd in PyTorch, but the error seems similar to https://github.com/pytorch/pytorch/issues/103726 and https://github.com/pytorch/pytorch/issues/90481.
PyTorch version: v2.0.1
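Roughly, the failing pattern looks like the sketch below (hypothetical shapes; a plain torch.utils.checkpoint call with use_reentrant=False stands in for the FSDP checkpoint wrapper, and the FSDP/ShardedGradScaler wrapping from the real training script is omitted):

```python
# Minimal sketch (hypothetical shapes; single GPU; FSDP/ShardedGradScaler omitted).
import torch
from torch.utils.checkpoint import checkpoint
from flash_attn import flash_attn_varlen_qkvpacked_func

device = "cuda"
nheads, headdim = 8, 64
seqlens = [128, 256]                                   # two sequences packed together
total = sum(seqlens)
cu_seqlens = torch.tensor([0, 128, 384], dtype=torch.int32, device=device)

proj = torch.nn.Linear(nheads * headdim, 3 * nheads * headdim,
                       device=device, dtype=torch.float16)
x = torch.randn(total, nheads * headdim, device=device, dtype=torch.float16)

def attn_block(x):
    qkv = proj(x).view(-1, 3, nheads, headdim)
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max(seqlens), causal=True)
    return out.reshape(-1, nheads * headdim)

# Non-reentrant checkpointing, like CheckpointImpl.NO_REENTRANT in the FSDP wrapper.
out = checkpoint(attn_block, x, use_reentrant=False)
# This backward is where the reported "No grad accumulator for a saved leaf" assert fires.
out.sum().backward()
```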
same question
@wjfwzzc Excuse me, I'm curious about your traceback printing, though it's unrelated to this issue. Could you tell me how to make Python print the stack trace like yours?
https://github.com/Delgan/loguru
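For example (a minimal hypothetical setup; the per-variable annotations in the traceback come from diagnose=True):

```python
import sys
from loguru import logger

logger.remove()
logger.add(sys.stderr, backtrace=True, diagnose=True)  # diagnose=True prints variable values

@logger.catch  # logs uncaught exceptions with the annotated traceback
def main():
    x = 0
    return 1 / x  # raises ZeroDivisionError just to demonstrate the output

main()
```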
@Legend94rz I train Stable Diffusion and this problem appeared. I think if you train any LLM or Stable Diffusion model with gradient_checkpointing enabled, you will see the same problem.
ping @tridao, would you fix this issue or provide a workaround? FSDP + activation checkpointing is a fairly common setting for large transformer training.
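For concreteness, the setting I mean is roughly the following (a hypothetical single-process sketch; the Block class and sizes are placeholders, and the flash-attn call inside the block is omitted):

```python
import functools, os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper,
)

# Single-process process group just so FSDP can be constructed for the example.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

class Block(torch.nn.Module):  # stand-in for a transformer block using flash-attn
    def __init__(self):
        super().__init__()
        self.ff = torch.nn.Linear(64, 64)
    def forward(self, x):
        return torch.relu(self.ff(x))

model = FSDP(torch.nn.Sequential(*[Block() for _ in range(4)]).cuda())

# Wrap each Block with non-reentrant activation checkpointing.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, Block),
)
```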
I'm not familiar with FSDP; can you post a short script to reproduce it?
Is the issue just activation checkpointing? Or is FSDP relevant?
Does anyone know if this has been fixed since then?
I think the issue applies to a raw torch.utils.checkpoint.checkpoint(..., use_reentrant=False) as well.
I think the problem is still there. I hit the error with the same function as you did.
Any update on this? I'm getting the same error.