[BUG] Passed the wrong type of argument to torch.distributed.broadcast.
Describe the bug
def broadcast_params(self):
    """
    Syncs parameters across all DP ranks.
    """
    for param in self.module.parameters():
        is_expert_parallel = not getattr(param, 'allreduce', True)
        if is_expert_parallel:
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_process_group_ranks(self.expert_data_parallel_group),
                group=self.expert_data_parallel_group,
            )
        else:
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_process_group_ranks(self.data_parallel_group),
                group=self.data_parallel_group,
            )
The src parameter of torch.distributed.broadcast should be of type int, indicating the global rank of the source to broadcast from. However, in the code above, the value passed is the list of all ranks in the data parallel group, as returned by torch.distributed.get_process_group_ranks.
The snippet above is from the DistributedDataParallel class in megatron/core/distributed/distributed_data_parallel.py.
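For reference, a minimal sketch of the distinction, assuming an already-initialized process group (the variable names are illustrative, not from Megatron-LM):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
group = dist.new_group(ranks=list(range(dist.get_world_size())))

ranks = dist.get_process_group_ranks(group)  # List[int], e.g. [0, 1, 2, 3]
root = ranks[0]                              # int: global rank of the group's rank 0

t = torch.zeros(1)
dist.broadcast(t, src=root, group=group)     # OK: src is a single int
# dist.broadcast(t, src=ranks, group=group)  # Wrong: src must not be a list of ranks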
To Reproduce N/A
Expected behavior N/A
Stack trace/logs N/A
Environment (please complete the following information):
- Megatron-LM commit ID: f3a3020
Proposed fix
The global rank corresponding to rank 0 of the (expert) data parallel group should be passed in instead.
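A minimal sketch of what such a fix could look like (not a committed patch). It assumes torch.distributed.get_global_rank is available (PyTorch >= 2.0); otherwise torch.distributed.get_process_group_ranks(group)[0] would serve the same purpose:

import torch

def broadcast_params(self):
    """
    Syncs parameters across all DP ranks.
    """
    for param in self.module.parameters():
        is_expert_parallel = not getattr(param, 'allreduce', True)
        group = (
            self.expert_data_parallel_group
            if is_expert_parallel
            else self.data_parallel_group
        )
        # Resolve the group's local rank 0 to its global rank, since
        # torch.distributed.broadcast expects `src` to be a single global rank (int).
        torch.distributed.broadcast(
            param.data,
            src=torch.distributed.get_global_rank(group, 0),
            group=group,
        )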
Additional context N/A