[BUG] MoE load balancing loss is accumulated twice when using activation checkpointing
Describe the bug
Load balancing loss is accumulated twice when using activation checkpointing.
To Reproduce
Train from scratch with and without --moe-layer-recompute, using --moe-router-load-balancing-type aux_loss in both runs.
Expected behavior
The load balancing loss should be the same in both settings and should be slightly above 1, where a value of 1 means the experts are fully balanced.
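To show why a value near 1 indicates balance, and why tracking the same value twice per step lands near 2, here is a small worked sketch. It uses a simplified Switch-Transformer-style formulation normalized so that perfect balance gives 1.0. The function name `switch_load_balancing_loss` and the exact normalization are illustrative assumptions, and Megatron's implementation may differ between versions.

```python
from typing import List


def switch_load_balancing_loss(probs: List[List[float]], topk: int) -> float:
    """Simplified Switch-style aux loss (illustrative, not Megatron's code).

    probs[t][e] is the router probability of token t for expert e.
    Assumption: each token is routed to its `topk` highest-probability experts.
    Normalized so that a perfectly balanced router yields 1.0.
    """
    num_tokens = len(probs)
    num_experts = len(probs[0])
    tokens_per_expert = [0] * num_experts
    for row in probs:
        # Top-k experts for this token (ties broken by expert index).
        chosen = sorted(range(num_experts), key=lambda e: -row[e])[:topk]
        for e in chosen:
            tokens_per_expert[e] += 1
    # Total router probability mass assigned to each expert.
    prob_mass_per_expert = [sum(row[e] for row in probs) for e in range(num_experts)]
    raw = sum(p * f for p, f in zip(prob_mass_per_expert, tokens_per_expert))
    return raw * num_experts / (num_tokens * num_tokens * topk)


# 4 tokens, 4 experts: each token prefers a different expert (balanced).
balanced = [[0.7 if e == t else 0.1 for e in range(4)] for t in range(4)]
# Every token prefers expert 0 (maximally skewed for top-1).
skewed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]

per_pass = switch_load_balancing_loss(balanced, topk=1)
print(per_pass)      # ≈ 1.0 (perfect balance)
print(2 * per_pass)  # ≈ 2.0 if the same value is tracked twice per step
print(switch_load_balancing_loss(skewed, topk=1))  # ≈ 2.8
```

A roughly 2x factor between the two settings, rather than a change in routing, is consistent with the same per-pass value being accumulated twice.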
Stack trace/logs
- without --moe-layer-recompute:
  iteration 10: load_balancing_loss: 1.091395E+00
  iteration 20: load_balancing_loss: 1.096082E+00
  iteration 30: load_balancing_loss: 1.037049E+00
- with --moe-layer-recompute:
  iteration 10: load_balancing_loss: 2.202137E+00
  iteration 20: load_balancing_loss: 2.298303E+00
  iteration 30: load_balancing_loss: 2.120842E+00
Environment (please complete the following information):
- Megatron-LM d4e72c0d33edc0c53aeb624f617eb77cebce6ae9
- PyTorch 2.4.1
- CUDA 12.1
- NCCL 2.20.5
Proposed fix
Replace `if self.training:` with `if self.training and torch.is_grad_enabled():`.
Reason: With --moe-layer-recompute, activation checkpointing executes the forward function twice: once in the original forward pass, where gradients are not enabled, and again during the backward pass to recompute activations. If the condition in `TopKRouter.aux_loss_load_balancing` (megatron/core/transformer/moe/router.py) is only `if self.training:`, the load balancing loss is accumulated on both passes. Changing the condition to `if self.training and torch.is_grad_enabled():` skips the accumulation during the first pass, while the standard training process without --moe-layer-recompute, where gradients are always enabled, remains unaffected.
A similar issue occurs with z_loss.
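The mechanism behind the proposed guard can be reproduced with plain PyTorch. This is a minimal sketch, not Megatron code: `ToyRouter` and `count_tracked_passes` are hypothetical names, the auxiliary loss is a placeholder, and `torch.utils.checkpoint.checkpoint(..., use_reentrant=True)` stands in for Megatron's own checkpointing, under the assumption that both run the first forward without gradients and recompute it with gradients enabled during backward.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyRouter(torch.nn.Module):
    """Hypothetical stand-in for a MoE router that tracks an aux loss for logging."""

    def __init__(self, guard_with_grad_mode: bool):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.guard_with_grad_mode = guard_with_grad_mode
        self.tracked_loss = 0.0  # hypothetical logging accumulator
        self.num_tracked = 0     # how many times the loss was accumulated

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.proj(x)
        # Placeholder auxiliary loss; only how often it is tracked matters here.
        aux_loss = logits.softmax(dim=-1).pow(2).sum()
        should_track = self.training
        if self.guard_with_grad_mode:
            # Proposed fix: skip tracking when autograd is disabled.
            should_track = should_track and torch.is_grad_enabled()
        if should_track:
            self.tracked_loss += aux_loss.item()
            self.num_tracked += 1
        return logits


def count_tracked_passes(guard_with_grad_mode: bool, use_checkpoint: bool = True) -> int:
    torch.manual_seed(0)
    router = ToyRouter(guard_with_grad_mode).train()
    x = torch.randn(2, 4, requires_grad=True)
    if use_checkpoint:
        # Reentrant checkpointing: the first forward runs under no_grad and is
        # recomputed with grad enabled when backward reaches this segment.
        out = checkpoint(router, x, use_reentrant=True)
    else:
        out = router(x)
    out.sum().backward()
    return router.num_tracked


print(count_tracked_passes(guard_with_grad_mode=False))  # 2: double counted
print(count_tracked_passes(guard_with_grad_mode=True))   # 1: fixed
print(count_tracked_passes(guard_with_grad_mode=True, use_checkpoint=False))  # 1
```

The guard relies on the first pass running without gradients. In PyTorch's non-reentrant mode (`use_reentrant=False`) the original forward runs with gradients enabled, so `torch.is_grad_enabled()` alone would not distinguish the two passes there. The same guard would apply to z_loss tracking.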
The fix is included in PR #1331.
Additional context
N/A
Thanks for reporting and fixing this. This is likely a display bug with no impact on convergence. We'll take your PR internally and help get it merged.
I see a similar doubling when turning CUDA graphs on versus off. For reference, see https://github.com/NVIDIA/Megatron-LM/issues/1462#issuecomment-2732642584.
However, the fix above does not resolve the doubling seen with CUDA graphs.
Marking as stale. No activity in 60 days.
Hi @thuwzt, we've merged the fix with you as co-author in May. Sorry for the late notice, and thanks for your contribution! https://github.com/NVIDIA/Megatron-LM/commit/e6d56d6828c0773f55772b92b2ec0eed5639665e