[BUG]: Got NaN during backward with ZeRO-2
Is there an existing issue for this bug?
- [X] I have searched the existing issues
🐛 Describe the bug
My code is based on Open-Sora and runs without any issues on 32 GPUs using ZeRO-2.
However, when using 64 GPUs, NaN appears in the gradient tensors after the second backward step.
As a workaround, I patched colossalai/zero/low_level/low_level_optim.py as follows:
```python
# line 313, in _run_reduction
flat_grads = bucket_store.get_flatten_grad()
flat_grads /= bucket_store.world_size
if torch.isnan(flat_grads).any():  # added
    raise RuntimeError(f"rank {dist.get_rank()} got nan on flat_grads")  # added
...
if received_grad.dtype != grad_dtype:
    received_grad = received_grad.to(grad_dtype)
if torch.isnan(received_grad).any():  # added
    raise RuntimeError(f"rank {dist.get_rank()} got nan on received_grad")  # added
...
```
With the patch above, my code runs normally and the loss looks fine.
I think it may be related to unsynchronized state between CUDA streams. I do not know the exact reason, and I do not think my workaround really solves the issue.
Does anyone on the team have an idea?
Environment
- GPU: NVIDIA H20
- ColossalAI version: 0.4.3
- CUDA version: 12.4
- PyTorch version: 2.4
FP16 ZeRO should automatically check for overflow and skip that step, though this seems to be unimplemented for bf16. @botbw, would you be able to take a look? I haven't been maintaining this part.
Since you're using Open-Sora, feel free to open this kind of issue in their repo too.
Indeed, bf16 has the same range as fp32, but in my opinion, couldn't this check be enforced for all precisions?
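Extending the overflow check to every precision, as suggested above, could look like the following minimal sketch in plain PyTorch. It assumes an unsharded optimizer and that all gradients of a rank live on one device; `all_grads_finite` and `guarded_step` are hypothetical helpers, not ColossalAI APIs. ColossalAI's `LowLevelZeroOptimizer` keeps its gradients sharded, so a real fix would have to live inside the optimizer itself.

```python
from __future__ import annotations

from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn


def all_grads_finite(params: Iterable[nn.Parameter]) -> bool:
    """Hypothetical helper: True if every existing .grad is finite on all ranks."""
    ok = None
    for p in params:
        if p.grad is None:
            continue
        finite = torch.isfinite(p.grad).all()  # 0-d bool tensor on the grad's device
        ok = finite if ok is None else ok & finite
    if ok is None:  # no gradients at all: nothing to skip
        ok = torch.tensor(True)
    flag = ok.to(torch.int32)  # 1 = finite on this rank, 0 = NaN/Inf found
    if dist.is_available() and dist.is_initialized():
        # Assumes the backend supports the flag's device (e.g. CUDA for NCCL).
        # MIN makes every rank skip if any rank saw a non-finite gradient.
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())


def guarded_step(model: nn.Module, optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> bool:
    """Hypothetical helper: backward, then step only if all grads are finite."""
    loss.backward()
    stepped = all_grads_finite(model.parameters())
    if stepped:
        optimizer.step()
    optimizer.zero_grad()  # drop the bad gradients either way
    return stepped
```

The `all_reduce` with `ReduceOp.MIN` keeps the decision consistent: without it, one rank could step while another skips, and the replicas would drift apart.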
Hi @Edenzzzz, thank you for getting involved.
Please note that after adding the lines above, NaN no longer occurs and these RuntimeErrors are never raised. Therefore, I don't think skipping a specific iteration would help. I suspect the bug is in the communication, which is also why I opened the issue in this repo.
Hmm, then maybe isnan triggers synchronization? You should check whether the NaN is in received_grad or in flat_grads. You can also try removing those lines and adding torch.cuda.synchronize() instead.
dist.reduce_scatter is synchronous here, and received_grad is initialized to zeros, not NaN, so communication might not be the issue.
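One way to test the synchronization idea is to keep the NaN checks but stop them from synchronizing. Using `torch.isnan(t).any()` in an `if`, as the patch does, converts the result to a Python `bool`; for a CUDA tensor this copies it to the host and blocks until the current stream has finished its queued work. Keeping the flag as a device tensor avoids that. The sketch below uses a hypothetical `DeferredNaNCheck` class (not part of ColossalAI or PyTorch) that records flags during the step and reads them only once at the end. If NaN shows up with this version but not with the original patch, that would support the synchronization hypothesis; it is a debugging aid under that assumption, not a fix.

```python
from __future__ import annotations

import torch


class DeferredNaNCheck:
    """Hypothetical helper: records NaN flags without forcing a host sync."""

    def __init__(self) -> None:
        self.flags: dict[str, torch.Tensor] = {}

    def record(self, name: str, t: torch.Tensor) -> None:
        # The 0-d flag stays on t's device; no bool()/.item() here, so no host sync.
        flag = torch.isnan(t).any()
        if name in self.flags:
            flag = torch.logical_or(self.flags[name], flag)
        self.flags[name] = flag

    def report(self) -> list[str]:
        # .item() is where the host finally waits for the device.
        return [name for name, flag in self.flags.items() if flag.item()]
```

Inside `_run_reduction`, one would call `record("flat_grads", flat_grads)` and `record("received_grad", received_grad)` in place of the two `if` blocks, then call `report()` once after `optimizer.step()`.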
@flymin Thanks for reporting this! Would it be possible to share the config or code snippet you are using? If not, could you try setting overlap_communication=False in LowLevelZeroPlugin and check whether the problem still exists? That would tell us whether there is a bug in communication stream synchronization.
Sorry, I cannot share my current code. I may have some time later in November to work on this issue again.
I have tried adding synchronization code in this function, but it does not help. I also tried overlap_communication=False, and the issue still exists.
In my workaround, I have to add both if blocks; either one alone does not help.
If you suspect communication issues, you can add dist.barrier().
You should also try printing values or inserting dist.breakpoint() to find the first variable that becomes NaN.
I've also used isnan for pipeline-parallel debugging, but didn't see it doing any special synchronization. In my case, a gradient was not received somewhere, and the bad value propagated across layers and produced NaN.
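To find the first tensor that becomes NaN, as suggested above, one option is a tensor hook on every parameter. The hook fires when autograd computes that parameter's gradient on the local rank, so it should flag NaN produced by the backward pass itself rather than by the gradient reduction. The sketch assumes a plain `nn.Module`; `watch_nonfinite_grads` is a hypothetical helper, not a ColossalAI API. PyTorch's built-in `torch.autograd.set_detect_anomaly(True)` is a heavier alternative that raises when a backward function returns NaN.

```python
from __future__ import annotations

import torch
import torch.distributed as dist
from torch import nn


def watch_nonfinite_grads(model: nn.Module) -> tuple[list[str], list]:
    """Hypothetical helper: record parameters whose incoming grad has NaN/Inf."""
    bad: list[str] = []
    handles = []
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        def hook(grad: torch.Tensor, name: str = name) -> None:
            # Fires as each grad is computed, so `bad` roughly follows backward order.
            # The isfinite check syncs with the host; acceptable for debugging.
            if not torch.isfinite(grad).all():
                bad.append(name)
                print(f"rank {rank}: non-finite grad for {name}")
            # Returning None leaves the gradient unchanged.

        handles.append(param.register_hook(hook))
    return bad, handles
```

Call `h.remove()` on each returned handle once the culprit is found, so the hooks stop adding overhead.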
dist.barrier() does not help. I will try dist.breakpoint() later. Thank you for your advice.