Where do dispatch and combine need to be synchronized?
I'd like to know where synchronization across all ranks is required for both dispatch and combine operations, using the following code that calls low_latency_dispatch and low_latency_combine as an example. Specifically:
Is synchronization across all ranks needed before dispatching SEND/RECV operations?
Is synchronization across all ranks needed after dispatching SEND/RECV operations?
Is synchronization across all ranks needed before combining SEND/RECV operations?
Is synchronization across all ranks needed after combining SEND/RECV operations?
I would greatly appreciate it if you could also explain the reasoning behind whether synchronization is required at these positions. Thank you!
import torch
from typing import List, Optional

def forward(self,
            hidden_states: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_idx: torch.LongTensor,
            up_weights: torch.Tensor,
            up_scale: torch.Tensor,
            down_weights: torch.Tensor,
            down_scale: torch.Tensor,
            expert_list: Optional[List[int]] = None):
    # dispatch: send each token to the ranks that host its top-k experts
    # (num_experts and expected_m are defined elsewhere in the original code)
    packed_recv_hidden, masked_m, self.handle, event, hook = (self.buffer_low_latency.low_latency_dispatch(
        hidden_states,
        topk_idx,
        self.num_max_dispatch_tokens_per_rank,
        num_experts,
        use_fp8=True,
        async_finish=not self.return_recv_hook,
        return_recv_hook=self.return_recv_hook,
    ))
    # wait for the dispatched data: call the receive hook if one was requested, otherwise wait on the event
    hook() if self.return_recv_hook else event.current_stream_wait()
    # compute: run the local experts on the received tokens
    out_states = self.experts.forward(packed_recv_hidden, up_weights, up_scale, down_weights, down_scale, masked_m, expected_m)
    # combine: send the expert outputs back to the ranks that own the tokens
    combined_hidden_states, event, hook = (self.buffer_low_latency.low_latency_combine(
        out_states,
        topk_idx,
        topk_weights.to(torch.float32),
        self.handle,
        async_finish=not self.return_recv_hook,
        return_recv_hook=self.return_recv_hook,
    ))
    # wait for the combined data in the same way as after dispatch
    hook() if self.return_recv_hook else event.current_stream_wait()
I tried to find where the synchronization is implemented by looking at the code, but I still don't fully understand. Your guidance would be of great help to me! Thanks very much!!! 🥺🥺🥺 @LyricZhao @fzyzcjy
Is synchronization across all ranks needed before dispatching SEND/RECV operations? Is synchronization across all ranks needed after dispatching SEND/RECV operations? Is synchronization across all ranks needed before combining SEND/RECV operations? Is synchronization across all ranks needed after combining SEND/RECV operations?
No to all four: no synchronization is needed.
For low-latency kernels, dispatch/combine just wait for data arrival; there is no bidirectional synchronization (or barrier).
For example, the only wait-data-arrival of dispatch is here: https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/internode_ll.cu#L492.
For more information, see #166.
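To make "wait for data arrival without a barrier" concrete, here is a minimal toy model in plain Python threads. It is not DeepEP's code: `RecvSlot`, `send`, `wait_for_arrival`, and the `count + 1` flag encoding are hypothetical names and choices made for illustration only. Each sender writes its payload into a slot owned by the receiver and then sets a flag; the receiver only polls its own slots and never signals back to the senders.

```python
import threading
import time


class RecvSlot:
    """Receive-side slot that exactly one remote sender writes into (hypothetical)."""

    def __init__(self):
        self.tokens = []
        self.count_flag = 0  # 0 means "nothing has arrived yet"


def send(slot, tokens, delay_s):
    """One-sided send: write the payload, then publish the flag. No reply is awaited."""
    time.sleep(delay_s)  # senders may start at different times
    slot.tokens = list(tokens)
    # Store count + 1 so that an empty send is distinguishable from "not arrived".
    slot.count_flag = len(tokens) + 1


def wait_for_arrival(slots):
    """Receiver side: poll each slot until its flag is set, then read the payload."""
    received = []
    for slot in slots:
        while slot.count_flag == 0:
            time.sleep(0.0001)  # spin until this sender's data has arrived
        num_tokens = slot.count_flag - 1
        received.append(slot.tokens[:num_tokens])
    return received


def run_demo():
    slots = [RecvSlot() for _ in range(3)]
    payloads = [["a0", "a1"], [], ["c0"]]
    delays = [0.02, 0.0, 0.01]
    senders = [
        threading.Thread(target=send, args=(slot, payload, delay))
        for slot, payload, delay in zip(slots, payloads, delays)
    ]
    for t in senders:
        t.start()
    received = wait_for_arrival(slots)
    for t in senders:
        t.join()
    return received


if __name__ == "__main__":
    print(run_demo())
```

In this sketch the receiver finishes as soon as the slowest sender's flag is visible, and no sender ever waits for the receiver. The flag also carries the token count, which shows one way a receiver could learn how much data to expect without a collective operation; whether DeepEP does exactly this is not confirmed in this thread.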
Thanks!!!!! I also have a question about end-to-end latency. Previously, we discussed the scenario where all RANKs start Dispatch/Combine at the same time. @LyricZhao
Taking the inference pipeline of an LLM as an example: Dispatch → Expert Group Gemm → Combine. Assume each RANK is allocated an equal number of tokens, meaning the latency of "Expert Group Gemm" is identical across RANKs. A question I'm concerned about is: What changes will occur in the end-to-end latency of each RANK? Can it be estimated as max(Dispatch latency) + Expert Group Gemm latency + max(Combine latency)?
If RANK 1 has more tokens and RANK 2 has fewer, then during the Dispatch phase, RANK 1 receives fewer tokens, resulting in lower Dispatch latency and allowing it to start Expert Group Gemm calculations earlier. However, when RANK 1 needs to begin Combine, it must wait for other RANKs to complete Expert Group Gemm. This way, RANK1 can receive the tokens sent by other RANKs. RANK 2 lags behind RANK 1 by (RANK2 Dispatch latency - RANK1 Dispatch latency). Does this imply the end-to-end time of RANK1/RANK2 should be RANK2 Dispatch latency + Expert Group Gemm latency + RANK1 Combine latency?
What changes will occur in the end-to-end latency of each RANK? Can it be estimated as max(Dispatch latency) + Expert Group Gemm latency + max(Combine latency)?
In such an ideal scenario (assuming attention is also the same), all latencies should be stable and equal. Nothing changes: the time should be dispatch latency (data amount / bandwidth) + GEMM latency + combine latency (data amount / bandwidth).
Does this imply the end-to-end time of RANK1/RANK2 should be RANK2 Dispatch latency + Expert Group Gemm latency + RANK1 Combine latency?
Yes, you are right.
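The estimate confirmed above can be reproduced with a small timeline model. This is only a sketch under stated assumptions: `RankCost` and `end_to_end` are hypothetical names, and the model assumes that a rank's receive can only complete after the slowest peer has started sending, which is a simplification (a real rank waits only for the peers it actually receives from).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RankCost:
    attention: float  # time before this rank starts dispatch
    dispatch: float   # data amount received in dispatch / bandwidth
    gemm: float       # expert group GEMM time
    combine: float    # data amount received in combine / bandwidth


def end_to_end(ranks: List[RankCost]) -> List[float]:
    """Return the finish time of each rank under the toy model."""
    # Assumption: dispatch data is complete only once the slowest rank has sent.
    dispatch_ready = max(r.attention for r in ranks)
    gemm_done = [dispatch_ready + r.dispatch + r.gemm for r in ranks]
    # Assumption: combine data is complete only once the slowest rank finished its GEMM.
    combine_ready = max(gemm_done)
    return [combine_ready + r.combine for r in ranks]


# RANK1 sends more tokens, so it receives fewer in dispatch and more in combine.
rank1 = RankCost(attention=0.0, dispatch=1.0, gemm=5.0, combine=3.0)
rank2 = RankCost(attention=0.0, dispatch=2.0, gemm=5.0, combine=1.5)
finish_times = end_to_end([rank1, rank2])

if __name__ == "__main__":
    print(finish_times)  # [10.0, 8.5]
```

With these made-up numbers, RANK1 finishes at 2.0 + 5.0 + 3.0 = 10.0, that is RANK2 dispatch latency + GEMM latency + RANK1 combine latency, which matches the formula in the question.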
Thank you very much for your detailed answer! I would like to confirm whether my understanding is correct. I am analyzing the impact of unbalanced batch sizes (affecting dispatch/combine latency) on end-to-end latency. This is very important to me. To make it easier to understand, I have drawn two diagrams to illustrate my question. @LyricZhao
[Figure 1: Batch Size Unbalance]
[Figure 2: Attention Unbalance, Batch Size Balance]
Based on Attention DP and Expert EP without enabling two micro-batch settings, I have four questions:
(1) When attentions across different RANKs are balanced but Batch Sizes are unbalanced, is the End-to-End Latency calculated as shown in Figure 1:
End_to_End Latency = Attention + max(RANK1 Dispatch, RANK0 Dispatch) + GroupGemm + max(RANK1 Combine, RANK0 Combine)?
(2) When attentions across different RANKs are unbalanced but Batch Sizes are balanced, is the End-to-End Latency calculated as shown in Figure 2:
End_to_End Latency = max(RANK1 Attention, RANK0 Attention) + RANK0 Dispatch + GroupGemm + RANK0 Combine?
(3) In the timeline of actual LLM inference captured with the torch profiler, the blank areas shown in Figures 1 and 2 do not appear. Is this because the Dispatch/Combine kernel is launched immediately after Attention/GroupGemm and then blocks internally until data from other RANKs arrives?
(4) If two micro-batches are enabled within RANK, will these two phenomena still occur? Will they become more severe or mitigated?
Thank you for your serious and detailed explanation! I appreciate any insights, and thanks for your great work on this project!
For example, the only wait-data-arrival of dispatch is here: https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/internode_ll.cu#L492.
How does a RANK know how many inputs it should receive from other RANKs? Does this require an operation similar to all-reduce (which requires synchronization among all RANKs)?
Can you contact me via WeChat (LyricZ_THU)?
I see the same situation. I wonder whether the waiting time in dispatch/combine can be eliminated with a barrier or hidden by overlap. However, since these operations are designed without bidirectional synchronization, I am not sure that an additional barrier can reduce the latency.