[BUG] wrong loss scaling when context parallel is on
Describe the bug
Hi, I think there is a bug when context parallel is on and we can discuss it. https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/pretrain_gpt.py#L148
From this issue,i know the result is same for dp2cp4 and dp8 with the same global batch_size.
However, the logic described in that issue differs from the current code logic:
- In the issue's logic, the loss is scaled by cp_size and grad_data is scaled by the world size of get_data_parallel_group(with_context_parallel=True).
- In the current code logic, the loss is scaled by cp_size, but grad_data is scaled by the world size of get_data_parallel_group().
The two logics produce different grad_data (print grad_data after all-reducing it to see the difference); a numeric sketch of the two multipliers follows below.
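A minimal sketch of the arithmetic implied by the two code paths above. The helper and variable names here are illustrative only (not Megatron-LM APIs); the sketch just shows that the net gradient multiplier differs by exactly cp_size depending on which data-parallel group size is used as the divisor.

```python
def grad_multiplier(cp_size: int, grad_group_size: int) -> float:
    """Net factor applied to a rank's gradient: the loss (and hence its
    gradient) is scaled by cp_size in loss_func, then the all-reduced
    gradient is divided by the size of the chosen data-parallel group."""
    return cp_size / grad_group_size

dp_size, cp_size = 2, 4  # the dp2cp4 configuration from this report

# Issue's logic: divide by the world size of get_data_parallel_group(with_context_parallel=True)
issue_logic = grad_multiplier(cp_size, dp_size * cp_size)   # 4 / 8 = 0.5

# Current code logic: divide by the world size of get_data_parallel_group()
current_logic = grad_multiplier(cp_size, dp_size)           # 4 / 2 = 2.0

print(current_logic / issue_logic)  # 4.0 == cp_size, i.e. the grad_data differ by cp_size
```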
To Reproduce
Running dp2cp4 and dp8 with the same parameters reproduces the difference.
Proposed fix
Remove the loss scaling by cp_size in loss_func.
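A sketch of what the fix would look like, assuming the loss_func shape from pretrain_gpt.py. The body below is a simplified paraphrase (the real function also reduces the loss across the data-parallel group for logging), not an exact copy of the Megatron-LM code.

```python
import torch


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Token-mean loss over the tokens local to this (sequence-sharded) rank."""
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Proposed change: return `loss` directly instead of
    # `loss * args.context_parallel_size`, and rely on the gradient
    # all-reduce over the data-parallel group to do the averaging.
    return loss, {'lm loss': loss}
```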