Megatron-LM
[BUG] pipeline_parallel is not available when pp_size > 2
This function is wrong: the program hangs because of the `group` argument.
```python
def _batched_p2p_ops(
    *,
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    group: torch.distributed.ProcessGroup,
):
```
After the modification:
```python
def _batched_p2p_ops(
    *,
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
):
```
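For context, a minimal sketch of what the modified function body might look like with the `group` argument dropped, so that `torch.distributed.P2POp` falls back to the default process group. This is an assumption about the intent of the fix, not the exact code in the repository; the helper names `get_pipeline_model_parallel_prev_rank` / `get_pipeline_model_parallel_next_rank` are the usual Megatron-LM rank helpers and the surrounding logic is simplified.

```python
from typing import Optional

import torch

# Assumed Megatron-LM helpers for the neighboring pipeline ranks;
# module path and names may differ across versions.
from megatron.core.parallel_state import (
    get_pipeline_model_parallel_next_rank,
    get_pipeline_model_parallel_prev_rank,
)


def _batched_p2p_ops(
    *,
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
):
    # Collect all point-to-point operations for this pipeline step.
    # No explicit group is passed, so P2POp uses the default process group.
    ops = []
    if tensor_send_prev is not None:
        ops.append(
            torch.distributed.P2POp(
                torch.distributed.isend,
                tensor_send_prev,
                get_pipeline_model_parallel_prev_rank(),
            )
        )
    if tensor_recv_prev is not None:
        ops.append(
            torch.distributed.P2POp(
                torch.distributed.irecv,
                tensor_recv_prev,
                get_pipeline_model_parallel_prev_rank(),
            )
        )
    if tensor_send_next is not None:
        ops.append(
            torch.distributed.P2POp(
                torch.distributed.isend,
                tensor_send_next,
                get_pipeline_model_parallel_next_rank(),
            )
        )
    if tensor_recv_next is not None:
        ops.append(
            torch.distributed.P2POp(
                torch.distributed.irecv,
                tensor_recv_next,
                get_pipeline_model_parallel_next_rank(),
            )
        )
    # Launch all sends/receives together and return the request handles.
    return torch.distributed.batch_isend_irecv(ops) if ops else []
```

With this change the peer ranks passed to `P2POp` are interpreted against the default (world) group, which presumably avoids the rank/group mismatch that caused `batch_isend_irecv` to hang when pp_size > 2.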