[BUG] P2P communication order error and hang when PP=2 and VPP=2 with pad removal
Describe the bug P2P communication is issued in the wrong order and training hangs when PP=2 and VPP=2 are used with pad removal (variable sequence lengths).
To Reproduce
When using PP=2 and VPP=2 with config.variable_seq_lengths=True, config.batch_p2p_comm=True, and config.overlap_p2p_comm=False, the current implementation of p2p_communication.py behaves incorrectly.
If we instead set config.overlap_p2p_comm=True and config.batch_p2p_comm=False, the bug disappears.
You can reproduce this issue with verl: set the actor's/critic's pipeline_model_parallel_size=2 and virtual_pipeline_model_parallel_size=1, and delete config.overlap_p2p_comm and config.batch_p2p_comm in verl/utils/megatron_utils.py so that the original Megatron-LM defaults are used.
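For reference, here is a minimal sketch of the two flag combinations, assuming the ModelParallelConfig dataclass exported by megatron.core (TransformerConfig inherits these fields); required fields and defaults may vary across Megatron-LM versions.

```python
# Hedged sketch only: field names follow the config attributes referenced above.
import torch
from megatron.core import ModelParallelConfig

# Combination that triggers the wrong P2P ordering and the hang.
buggy = ModelParallelConfig(
    pipeline_model_parallel_size=2,
    virtual_pipeline_model_parallel_size=2,
    pipeline_dtype=torch.bfloat16,  # required once pipeline parallelism is enabled
    variable_seq_lengths=True,      # pad removal -> per-microbatch sequence lengths differ
    batch_p2p_comm=True,
    overlap_p2p_comm=False,
)

# Workaround: disable batched P2P and use overlapped P2P instead.
workaround = ModelParallelConfig(
    pipeline_model_parallel_size=2,
    virtual_pipeline_model_parallel_size=2,
    pipeline_dtype=torch.bfloat16,
    variable_seq_lengths=True,
    batch_p2p_comm=False,
    overlap_p2p_comm=True,
)
```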
Expected behavior As illustrated in the image below:
After both devices finish at the dashed time, Device 1 should pass output_tensor and input_tensor_grad to Device 2. Because the world size is 2, both devices have the same next_rank and prev_rank, so the original ring communication degenerates into two-way communication between the same pair of ranks, which causes conflicts in p2p_communication. In detail, Device 1 sends output_tensor to next_rank and input_tensor_grad to prev_rank, while Device 2 posts receives for output_tensor_grad from next_rank and input_tensor from prev_rank, and these operations end up matched across the wrong directions.
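To make the failure mode concrete, here is a hypothetical standalone torch.distributed sketch (not Megatron code; the tensor contents are just the first shape dimension from the logs below). With only two pipeline ranks, prev_rank and next_rank are the same peer, and batched P2P operations between the same pair of ranks are matched purely by posting order, so the "prev" and "next" directions can be paired the wrong way around:

```python
# Hypothetical repro sketch (assumptions: 2 GPUs, NCCL backend, launched via torchrun).
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    assert dist.get_world_size() == 2, "this sketch assumes exactly two ranks"
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    next_rank = (rank + 1) % 2
    prev_rank = (rank - 1) % 2  # same peer as next_rank when the world size is 2

    if rank == 0:
        # Intended semantics: send_next goes "forward", send_prev goes "backward".
        send_next = torch.full((1,), 1702.0, device=device)
        send_prev = torch.full((1,), 1673.0, device=device)
        ops = [
            dist.P2POp(dist.isend, send_prev, prev_rank),
            dist.P2POp(dist.isend, send_next, next_rank),
        ]
    else:
        recv_prev = torch.empty(1, device=device)  # should receive rank 0's send_next (1702)
        recv_next = torch.empty(1, device=device)  # should receive rank 0's send_prev (1673)
        ops = [
            dist.P2POp(dist.irecv, recv_prev, prev_rank),
            dist.P2POp(dist.irecv, recv_next, next_rank),
        ]

    for req in dist.batch_isend_irecv(ops):
        req.wait()
    torch.cuda.synchronize()

    if rank == 1:
        # Both sends target the same peer, so they are matched with the receives in
        # posting order: recv_prev gets 1673 and recv_next gets 1702 -- the reversed
        # pairing reported as "Reverse Error" in the logs below.
        print(f"recv_prev={recv_prev.item()}, recv_next={recv_next.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

With variable_seq_lengths=True this reversed pairing first shows up in the shape handshake (see the logs below), after which the tensor transfers are posted with mismatched sizes and the pipeline gets stuck.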
Stack trace/logs
Here are additional logs showing the shape tensors exchanged by each device:
# Device 0
send_prev_shape_tensor: torch.Tensor([1673, 1, 3840], device='cuda:0'), send_next_shape_tensor: torch.Tensor([1702, 1, 3840], device='cuda:0')
recv_prev_shape_tensor: torch.Tensor([], device='cuda:0'), recv_next_shape_tensor: torch.Tensor([1664, 1, 3840], device='cuda:0')
# Device 1
send_prev_shape_tensor: torch.Tensor([1664, 1, 3840], device='cuda:0'), send_next_shape_tensor: torch.Tensor([1653, 1, 3840], device='cuda:0')
recv_prev_shape_tensor: torch.Tensor([1673, 1, 3840], device='cuda:0'), recv_next_shape_tensor: torch.Tensor([1702, 1, 3840], device='cuda:0') # Reverse Error
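Reading this shape handshake against the ring protocol: a rank's recv_prev shape should equal its previous rank's send_next shape, and its recv_next shape should equal its next rank's send_prev shape. A small hypothetical helper (sequence lengths copied from the logs above, first dimension only) makes the mismatch explicit:

```python
# Hypothetical log-checking helper; values are the first shape dimension from the logs above.
def check_shape_handshake(dev0, dev1):
    """In the two-rank ring, each device's prev and next are both the other device."""
    expected_pairs = [
        ("Device 1 recv_prev", dev1["recv_prev"], dev0["send_next"]),
        ("Device 1 recv_next", dev1["recv_next"], dev0["send_prev"]),
        ("Device 0 recv_next", dev0["recv_next"], dev1["send_prev"]),
    ]
    for name, got, want in expected_pairs:
        status = "ok" if got == want else "MISMATCH"
        print(f"{name}: got {got}, expected {want} -> {status}")


device0 = {"send_prev": 1673, "send_next": 1702, "recv_prev": None, "recv_next": 1664}
device1 = {"send_prev": 1664, "send_next": 1653, "recv_prev": 1673, "recv_next": 1702}
check_shape_handshake(device0, device1)
```

This prints MISMATCH for both of Device 1's receives (their shapes are swapped) while Device 0's recv_next is paired correctly, which matches the "Reverse Error" annotation above.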
Environment (please complete the following information):
- Megatron-LM core_r0.11.0
- PyTorch 2.4.0
- CUDA 12.4
- NCCL 2.20.5
Proposed fix See PR #1451.