
[QUESTION] E2E Overlap: Flux design

Open rajagond opened this issue 9 months ago • 9 comments

Hi,

I want to understand how you implemented overlap in e2e.

Let’s take all-reduce after MLP as an example. You break the all-reduce into reduce-scatter and all-gather. Reduce-scatter is overlapped with MLP2 of the same layer. What about all-gather?

Similarly, for the all-reduce after attention, you overlap all-gather with MLP1. What about reduce-scatter? Are you overlapping this with post-projection, or is it not overlapped? If it is the former, isn’t post-projection too small?

rajagond avatar Mar 30 '25 05:03 rajagond

@wenlei-bao @houqi

rajagond avatar Apr 01 '25 09:04 rajagond

For dense models with sequence parallelism, AR + LN is converted into RS + LN + AG,

where AG stands for AllGather, LN for LayerNorm, and RS for ReduceScatter.

For the FFN part, the AG can be fused with the FFN-0 GEMM and the RS with the FFN-1 GEMM; see the sketch below.

You can do the same in the attention part.

For MoE, it is just like dense.
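
To make the decomposition concrete, here is a minimal sketch of the FFN part using plain torch.distributed collectives in place of the Flux fused kernels (the fusion points are only marked in comments; shapes and the function itself are assumptions for illustration, not Flux code):

```python
# Sketch only: plain collectives stand in for the Flux fused AG+GEMM and
# GEMM+RS kernels; comments mark where the fusion/overlap would happen.
import torch
import torch.distributed as dist

def ffn_block_tp_sp(x_local, ln, w0, w1, tp_group):
    """x_local: sequence-parallel shard of shape [S/tp, B, H]."""
    tp = dist.get_world_size(tp_group)

    y_local = ln(x_local)                      # LN on the [S/tp, B, H] shard

    # AG: [S/tp, B, H] -> [S, B, H]; in Flux this is fused with the FFN-0 GEMM
    y_full = torch.empty((y_local.shape[0] * tp, *y_local.shape[1:]),
                         dtype=y_local.dtype, device=y_local.device)
    dist.all_gather_into_tensor(y_full, y_local, group=tp_group)

    h = torch.nn.functional.gelu(y_full @ w0)  # FFN-0, column-parallel: H x 4H/tp

    out_partial = h @ w1                       # FFN-1, row-parallel: 4H/tp x H,
                                               # yields partial sums of [S, B, H]

    # RS: partial [S, B, H] -> reduced [S/tp, B, H]; in Flux this is fused
    # with the FFN-1 GEMM
    out_local = torch.empty_like(x_local)
    dist.reduce_scatter_tensor(out_local, out_partial, group=tp_group)
    return out_local
```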

houqi avatar Apr 02 '25 08:04 houqi


In the end-to-end (E2E) implementation, you have used Tensor Parallelism, correct?
Sorry, I’m a bit confused.

Dense Baseline Workflow

pre-projection → Attention → post-projection → All Reduce → LayerNorm → FFN 0-GEMM → activation → FFN 1-GEMM → All Reduce → LayerNorm

Overlapped Tensor Parallelism Workflow

pre-projection → Attention → post-projection → Reduce Scatter → LayerNorm → Fused(AllGather, FFN 0-GEMM) → activation → Fused(FFN 1-GEMM, Reduce Scatter) → LayerNorm → AllGather

Questions
  • How are you handling Reduce Scatter (RS) after post-projection?
  • How are you handling AllGather (AG) at the end?

rajagond avatar Apr 02 '25 09:04 rajagond

In the end-to-end (E2E) implementation, you have used Tensor Parallelism, correct?

yes

How are you handling Reduce Scatter (RS) after post-projection?

The post-projection is a GEMM too, but it is usually H × H, which is small compared with FFN-0 (H × 4H) or FFN-1 (4H × H), so fusing the post-projection with the following RS may not bring any benefit. You have to test it yourself and check whether it helps.
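
For intuition, a rough FLOP comparison (example numbers, not from the paper) shows why the overlap window behind the post-projection is short:

```python
# Illustrative only: per-token GEMM FLOPs for a hypothetical hidden size H.
H = 4096                       # example value, not from the paper
post_proj = 2 * H * H          # attention post-projection: H x H
ffn0 = 2 * H * 4 * H           # FFN-0: H x 4H (same ratio per rank under TP)
print(ffn0 / post_proj)        # 4.0 -> the post-projection GEMM is ~4x smaller,
                               # so it leaves much less compute to hide the RS behind
```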

How are you handling AllGather (AG) at the end?

The AG can be fused with the pre-projection of the next layer.

houqi avatar Apr 02 '25 09:04 houqi

I have read articles about Flux and noticed that the paper describes a TP+SP approach in the Transformer, not pure TP. To confirm: during the decoding phase of inference, is the same TP+SP scheme applied, but with the split along the batch dimension instead of the sequence length (since the sequence length is 1 during decoding)? Could you please confirm whether my understanding is correct? @houqi

shenyt-sanshui avatar Apr 18 '25 08:04 shenyt-sanshui

I have read articles about Flux and noticed that the paper describes a TP+SP approach in the Transformer, not pure TP. To confirm: during the decoding phase of inference, is the same TP+SP scheme applied, but with the split along the batch dimension instead of the sequence length (since the sequence length is 1 during decoding)? Could you please confirm whether my understanding is correct? @houqi

@wenlei-bao

houqi avatar Apr 18 '25 23:04 houqi

In the end-to-end (E2E) implementation, you have used Tensor Parallelism, correct? Sorry, I’m a bit confused.

Dense Baseline Workflow

pre-projection → Attention → post-projection → All Reduce → LayerNorm → FFN 0-GEMM → activation → FFN 1-GEMM → All Reduce → LayerNorm

Overlapped Tensor Parallelism Workflow

pre-projection → Attention → post-projection → Reduce Scatter → LayerNorm → Fused(AllGather, FFN 0-GEMM) → activation → Fused(FFN 1-GEMM, Reduce Scatter) → LayerNorm → AllGather

Questions
  • How are you handling Reduce Scatter (RS) after post-projection?
  • How are you handling AllGather (AG) at the end?

@rajagond If you are familiar with loop unrolling, it is just like unrolling by a factor of 2: think about this iteration together with the next iteration at the same time.
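
A sketch of that view as a layer loop (hypothetical helper names, not the Flux API): the AG that would end layer i is simply the AG fused into the QKV (pre-projection) GEMM of layer i+1, so it never appears as a standalone collective.

```python
# Sketch of the unroll-by-2 view (hypothetical helpers, not the Flux API;
# residual connections and dropout omitted).
def run_layers(x_local, layers):
    # x_local: sequence-parallel shard [S/tp, B, H]
    for layer in layers:
        # AG fused with the QKV GEMM -- conceptually this AG terminates the
        # previous layer's RS + LN + AG decomposition.
        q, k, v = layer.ag_qkv_gemm(x_local)                # hypothetical fused op
        attn = layer.attention(q, k, v)
        partial = layer.post_projection(attn)               # small H x H GEMM
        x_local = layer.reduce_scatter(partial)             # RS (fusing it here is optional)
        x_local = layer.ln1(x_local)
        h = layer.activation(layer.ag_ffn0_gemm(x_local))   # hypothetical fused op
        x_local = layer.ffn1_gemm_rs(h)                     # hypothetical fused op
        x_local = layer.ln2(x_local)
        # No standalone AG here: it is deferred into ag_qkv_gemm of the next
        # iteration, i.e. think of iteration i and i+1 together.
    return x_local
```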

wenlei-bao avatar Apr 23 '25 23:04 wenlei-bao

I have read articles about Flux and noticed that the paper describes a TP+SP approach in the Transformer, not pure TP. To confirm: during the decoding phase of inference, is the same TP+SP scheme applied, but with the split along the batch dimension instead of the sequence length (since the sequence length is 1 during decoding)? Could you please confirm whether my understanding is correct? @houqi

@Linus-Voss which part are you referring to? I don't think we have released the TP/SP support yet.

wenlei-bao avatar Apr 23 '25 23:04 wenlei-bao

I have read articles about Flux and noticed that the paper describes a TP+SP approach in the Transformer, not pure TP. To confirm: during the decoding phase of inference, is the same TP+SP scheme applied, but with the split along the batch dimension instead of the sequence length (since the sequence length is 1 during decoding)? Could you please confirm whether my understanding is correct? @houqi

Can padding the matrix solve the problem here? @Linus-Voss

zfy3000163 avatar Jun 09 '25 07:06 zfy3000163