[QUESTION] E2E Overlap: Flux design
Hi,
I want to understand how you implemented overlap in e2e.
Let’s take all-reduce after MLP as an example. You break the all-reduce into reduce-scatter and all-gather. Reduce-scatter is overlapped with MLP2 of the same layer. What about all-gather?
Similarly, for the all-reduce after attention, you overlap all-gather with MLP1. What about reduce-scatter? Are you overlapping this with post-projection, or is it not overlapped? If it is the former, isn’t post-projection too small?
@wenlei-bao @houqi
For dense models with sequence parallelism, AR + LN is converted into RS + LN + AG,
where AR is AllReduce, AG is AllGather, LN is LayerNorm, and RS is ReduceScatter.
For the FFN part, the AG can be fused with the FFN0 GEMM and the RS can be fused with the FFN1 GEMM.
You can do the same in the attention part.
For MoE, it is just like the dense case.
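In case it helps, here is a minimal, unfused PyTorch sketch of the AR + LN ≡ RS + LN + AG equivalence under sequence parallelism. The function and variable names are placeholders (not the Flux API), and it assumes the sequence length is divisible by the TP size; Flux's contribution is overlapping the RS/AG with the neighbouring GEMMs, which this sketch does not do.

```python
# Minimal sketch: with sequence parallelism, AllReduce + LayerNorm can be
# replaced by ReduceScatter + LayerNorm + AllGather (same math; the RS and AG
# can then be overlapped with the GEMMs before/after them).
import torch
import torch.distributed as dist

def ar_ln(x_partial, ln, tp_group):
    # Baseline: all-reduce the partial GEMM output, then LayerNorm.
    x = x_partial.clone()
    dist.all_reduce(x, group=tp_group)
    return ln(x)

def rs_ln_ag(x_partial, ln, tp_group):
    # SP version: reduce-scatter along the sequence dim, LayerNorm the local
    # shard, then all-gather. In Flux the RS is fused into the GEMM that
    # produced x_partial and the AG into the GEMM that consumes the result.
    tp = dist.get_world_size(tp_group)
    shard = torch.empty(x_partial.shape[0] // tp, x_partial.shape[1],
                        device=x_partial.device, dtype=x_partial.dtype)
    dist.reduce_scatter_tensor(shard, x_partial, group=tp_group)
    y_shard = ln(shard)
    full = torch.empty_like(x_partial)
    dist.all_gather_into_tensor(full, y_shard, group=tp_group)
    return full
```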
In the end-to-end (E2E) implementation, you have used Tensor Parallelism, correct?
Sorry, I’m a bit confused.
Dense Baseline Workflow
pre-projection → Attention → post-projection → All Reduce → LayerNorm → FFN 0-GEMM → activation → FFN 1-GEMM → All Reduce → LayerNorm
Overlapped Tensor Parallelism Workflow
pre-projection → Attention → post-projection → Reduce Scatter → LayerNorm → Fused(AllGather, FFN 0-GEMM) → activation → Fused(FFN 1-GEMM, Reduce Scatter) → LayerNorm → AllGather
Questions
- How are you handling Reduce Scatter (RS) after post-projection?
- How are you handling AllGather (AG) at the end?
In the end-to-end (E2E) implementation, you have used Tensor Parallelism, correct?
yes
How are you handling Reduce Scatter (RS) after post-projection?
The post-projection is a GEMM too. It is usually H x H, which is small compared with FFN-0 (H x 4H) or FFN-1 (4H x H), so fusing the post-projection with the following RS may not bring any benefit. You have to test it yourself and check whether it helps.
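For a rough sense of scale, a back-of-the-envelope per-token FLOP comparison (illustrative only; H = 8192 is just an assumed hidden size):

```python
# Back-of-the-envelope per-token GEMM FLOPs (illustrative numbers only).
H = 8192                       # assumed hidden size
post_proj = 2 * H * H          # post-projection: H x H
ffn0      = 2 * H * (4 * H)    # FFN-0: H x 4H
ffn1      = 2 * (4 * H) * H    # FFN-1: 4H x H

# The post-projection has ~1/4 the compute of either FFN GEMM, so there is
# much less GEMM time available to hide a reduce-scatter behind.
print(post_proj / ffn0, post_proj / ffn1)   # 0.25 0.25
```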
How are you handling AllGather (AG) at the end?
The AG can be fused with the pre-projection of the next layer.
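Putting the answers above together, here is a conceptual, unfused sketch of one layer in the overlapped schedule (placeholder names, residual connections omitted, not the Flux API). The comments mark where Flux overlaps each collective with the adjacent GEMM; note that the "trailing AG" never appears as a standalone collective because it attaches to the next layer's pre-projection GEMM.

```python
# Conceptual, unfused sketch of one layer with sequence-sharded input/output.
import torch
import torch.distributed as dist
import torch.nn.functional as F

def layer_sharded(x_shard, attn_fn, w_qkv, w_post, w_ffn0, w_ffn1,
                  ln1, ln2, tp_group):
    """attn_fn and the weight shards are caller-supplied placeholders."""
    tp = dist.get_world_size(tp_group)

    def ag(shard):
        # In Flux this all-gather is overlapped with the GEMM that follows it.
        full = torch.empty(shard.shape[0] * tp, shard.shape[1],
                           device=shard.device, dtype=shard.dtype)
        dist.all_gather_into_tensor(full, shard, group=tp_group)
        return full

    def rs(partial):
        # In Flux this reduce-scatter is overlapped with the GEMM before it.
        out = torch.empty(partial.shape[0] // tp, partial.shape[1],
                          device=partial.device, dtype=partial.dtype)
        dist.reduce_scatter_tensor(out, partial, group=tp_group)
        return out

    # The previous layer's "trailing AG" happens here, attached to the
    # pre-projection (QKV) GEMM instead of running standalone.
    qkv = ag(x_shard) @ w_qkv
    attn = attn_fn(qkv)

    # Post-projection (logically an H x H GEMM) -> RS -> LN.
    x_shard = ln1(rs(attn @ w_post))

    # AG fused with the FFN0 GEMM, RS fused with the FFN1 GEMM.
    h = F.gelu(ag(x_shard) @ w_ffn0)
    x_shard = ln2(rs(h @ w_ffn1))

    # Output stays sequence-sharded; the next layer's ag(...) @ w_qkv consumes it.
    return x_shard
```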
I have read articles about Flux and noticed that the paper mentions a TP+SP approach in Transformer, not pure TP. To confirm: During the decoding phase of the inference stage, is the same TP+SP scheme applied, but with splitting performed along the batch dimension instead of the sequence length (since the sequence length is 1 in the decoding phase)? Could you please confirm if my understanding is correct? @houqi
@wenlei-bao
@rajagond If you are familiar with loop unrolling, it is just like unrolling by a factor of 2: think about this iteration and the next iteration at the same time.
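For what it's worth, here is the analogy spelled out on a toy loop (illustrative only, not Flux code): unrolling by 2 makes it visible that the communication at the end of one iteration sits right next to the compute at the start of the next, which is exactly the pair that gets overlapped.

```python
# Toy illustration of the "unroll by 2" reading of the schedule.
def compute(x):   # stands in for a layer's GEMMs
    return x + 1

def comm(x):      # stands in for the trailing all-gather
    return x * 1

def rolled(x, n):
    # ... compute(i); comm(i); compute(i+1); comm(i+1); ...
    for _ in range(n):
        x = comm(compute(x))
    return x

def unrolled_by_2(x, n):
    assert n % 2 == 0
    for _ in range(n // 2):
        x = comm(compute(x))   # iteration i:   comm(i) is now visibly ...
        x = comm(compute(x))   # iteration i+1: ... adjacent to compute(i+1)
    return x
```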
I have read articles about Flux and noticed that the paper mentions a TP+SP approach in Transformer, not pure TP. To confirm: during the decoding phase of the inference stage, is the same TP+SP scheme applied, but with splitting performed along the batch dimension instead of the sequence length (since the sequence length is 1 in the decoding phase)? Could you please confirm if my understanding is correct? @houqi
@Linus-Voss Which part are you referring to? I don't think we have released the TP/SP support yet.
Can padding the matrix solve the problem here? @Linus-Voss
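For reference, a sketch of what padding could mean here (this only illustrates the question above; it is not a confirmed Flux feature): pad the batch/token (M) dimension up to a multiple of the TP world size so the reduce-scatter/all-gather can split along it even when the sequence length is 1. All names are placeholders.

```python
# Pad the M (batch/token) dimension so collectives can split along it.
import torch
import torch.distributed as dist
import torch.nn.functional as F

def pad_to_multiple(x, multiple):
    m = x.shape[0]
    pad = (-m) % multiple
    # Pad extra rows at the end; remember the original M to unpad later.
    return F.pad(x, (0, 0, 0, pad)), m

def rs_with_padding(x_partial, tp_group):
    tp = dist.get_world_size(tp_group)
    x_padded, m = pad_to_multiple(x_partial, tp)
    shard = torch.empty(x_padded.shape[0] // tp, x_padded.shape[1],
                        device=x_padded.device, dtype=x_padded.dtype)
    dist.reduce_scatter_tensor(shard, x_padded, group=tp_group)
    return shard, m   # slice back to m rows after the matching all-gather
```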