[Question] How does the gate coordinate across ranks in expert parallelism?

Open · wangyaojlu opened this issue 8 months ago · 3 comments

Hi, I'm trying to understand how the Gate module works in Tutel's MoE implementation.

Each rank maintains only a subset of experts (num_experts_per_device), yet the Gate output appears to span the total number of experts globally, so I'm curious how the gates work across different ranks.

Specifically:

Does each rank need to maintain the same gate output?

Is there any communication happening after the gate, such as All-to-All, to route tokens to the correct experts?
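To make the shape question concrete, here is a small sketch of how a gate whose output spans all global experts can be mapped back to the rank that owns each expert. It assumes a contiguous layout (rank r owns the num_experts_per_device experts starting at r * num_experts_per_device) and uses hypothetical helper names; it is an illustration, not Tutel's internal code.

```python
def num_global_experts(num_experts_per_device: int, world_size: int) -> int:
    # The gate scores every expert in the job, so its output width is this total.
    return num_experts_per_device * world_size


def expert_location(global_expert_id: int, num_experts_per_device: int) -> tuple:
    # Assumed contiguous layout: rank r owns global experts
    # [r * num_experts_per_device, (r + 1) * num_experts_per_device).
    rank, local_slot = divmod(global_expert_id, num_experts_per_device)
    return rank, local_slot


# 4 ranks x 2 experts per device = 8 global experts.
total = num_global_experts(num_experts_per_device=2, world_size=4)
placement = {e: expert_location(e, 2) for e in range(total)}
print(placement[5])  # (2, 1): global expert 5 is local slot 1 on rank 2
```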

The reason I'm asking is that after training the model, I noticed that the Gate parameters are not identical across different ranks. I would like to ask whether this behavior is expected or indicates a problem.

wangyaojlu, Apr 27 '25 07:04

Q: Does each rank need to maintain the same gate output? Each rank's input is a local batch of data, so its gating output is likewise a local selection of experts for that batch.
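A toy, single-process sketch of that flow: each rank scores only its own local tokens against all global experts, buckets them by destination rank, and an All-to-All exchange delivers each bucket to the rank that holds the chosen expert. MoE layers such as Tutel's perform this kind of exchange internally; the functions below are hypothetical, use top-1 routing for simplicity, and simulate the collective with a list transpose instead of calling `torch.distributed.all_to_all`.

```python
from typing import List, Tuple


def top1_expert(gate_weight: List[List[float]], x: List[float]) -> int:
    # gate_weight has one row per *global* expert; pick the highest-scoring row.
    scores = [sum(w * v for w, v in zip(row, x)) for row in gate_weight]
    return max(range(len(scores)), key=scores.__getitem__)


def local_dispatch(gate_weight, tokens, num_experts_per_device, world_size):
    # Each rank gates only its own local batch, then buckets tokens by owning rank.
    send: List[List[Tuple[int, int]]] = [[] for _ in range(world_size)]
    for tok_id, x in tokens:
        expert = top1_expert(gate_weight, x)
        send[expert // num_experts_per_device].append((tok_id, expert))
    return send


def simulated_all_to_all(send_buffers):
    # recv[dst][src] = send[src][dst]: the exchange an All-to-All collective performs.
    world_size = len(send_buffers)
    return [[send_buffers[src][dst] for src in range(world_size)]
            for dst in range(world_size)]


# 2 ranks x 2 experts per device; both ranks hold the same (replicated) gate weight.
gate_weight = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
rank0_tokens = [(0, [1.0, 0.0]), (1, [0.0, -1.0])]
rank1_tokens = [(2, [-1.0, 0.0]), (3, [0.0, 1.0])]
send = [local_dispatch(gate_weight, t, 2, 2) for t in (rank0_tokens, rank1_tokens)]
recv = simulated_all_to_all(send)
print(recv[0])  # [[(0, 0)], [(3, 1)]]: rank 0 receives tokens for experts 0 and 1
```

The gate output differs per rank because the inputs differ, while the gate weight itself is the same everywhere; only the token exchange crosses ranks.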

What do you mean by "the Gate parameters are not identical across different ranks"? Do you mean the gating layer's parameters or the expert layers' parameters?

ghostplant, Apr 27 '25 17:04

Thank you for your reply. I meant that the "gating layer's parameters" are different across ranks.

Assume I have 4 GPUs, with each rank holding 1 expert (num_experts_per_device=1), for a total of 4 experts. The gate parameters on each rank have shape [4, embedding_dim], which suggests that each rank holds the full set of gate parameters. However, I've noticed that the gating layer's parameters differ across ranks.
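The arithmetic behind that observation, as a sketch: the gate maps model_dim to one score per global expert, so its weight has shape [num_global_experts, model_dim] on every rank, while the expert weights cover only the local experts. The two-layer expert FFN and the model_dim/hidden_dim values below are illustrative assumptions, not Tutel's defaults.

```python
def params_per_rank(model_dim: int, hidden_dim: int,
                    num_experts_per_device: int, world_size: int) -> dict:
    num_global_experts = num_experts_per_device * world_size
    # Gate: weight of shape [num_global_experts, model_dim], replicated on every rank.
    gate = num_global_experts * model_dim
    # Local experts (assumed two-layer FFN, biases omitted): only this rank's share.
    local_experts = num_experts_per_device * 2 * model_dim * hidden_dim
    return {"gate_shape": (num_global_experts, model_dim),
            "gate": gate, "local_experts": local_experts}


# The setup described above: 4 GPUs, 1 expert per device.
counts = params_per_rank(model_dim=1024, hidden_dim=4096,
                         num_experts_per_device=1, world_size=4)
print(counts["gate_shape"])  # (4, 1024)
```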

I wanted to confirm whether this behavior is normal or if it indicates an issue.

wangyaojlu, Apr 28 '25 01:04

Nope, they should be identical if you use standard DDP, since they are treated as shared parameters.

However, if you use FSDP/ZeRO-2/ZeRO-3 provided by other frameworks (e.g. DeepSpeed), these parameters may be improperly shared, which can result in non-identical values.
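A simplified picture of why a sharded gate can look different per rank: under ZeRO-3 or FSDP full sharding, each rank stores only a slice of a flattened parameter, so inspecting the local tensor on different ranks shows different values even when the logical parameter is consistent. The sketch below is a toy model of flat partitioning with padding, not DeepSpeed's or FSDP's actual layout; in those frameworks the full value is typically read inside `deepspeed.zero.GatheredParameters` or FSDP's `summon_full_params` context manager.

```python
from typing import List


def partition(flat: List[float], world_size: int) -> List[List[float]]:
    # Pad so the parameter divides evenly, then give each rank one contiguous shard.
    shard = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard * world_size - len(flat))
    return [padded[r * shard:(r + 1) * shard] for r in range(world_size)]


def gather(shards: List[List[float]], numel: int) -> List[float]:
    # Concatenate shards in rank order and drop the padding to recover the full parameter.
    return [x for s in shards for x in s][:numel]


# A toy [2 experts, 5 dims] gate weight, flattened and split over 4 ranks.
gate = [float(i) for i in range(10)]
shards = partition(gate, world_size=4)
print(shards[0], shards[3])  # [0.0, 1.0, 2.0] [9.0, 0.0, 0.0]
```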

Another possibility: if you handle shared-parameter updates manually and forget to apply all-reduce to the gate parameters, they won't be synchronized to identical values.
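A minimal numeric sketch of the last two points, assuming plain SGD and two ranks: when the gradients of a replicated parameter are averaged across ranks before the update (what DDP does for shared parameters), the replicas stay identical; when each rank steps on its local gradient only, they diverge after a single step. The helper names are hypothetical; in real code the averaging would be a `torch.distributed.all_reduce` with `ReduceOp.SUM` followed by division by the world size.

```python
from typing import List


def sgd_step(params: List[float], grads: List[float], lr: float) -> List[float]:
    return [p - lr * g for p, g in zip(params, grads)]


def all_reduce_mean(per_rank_grads: List[List[float]]) -> List[float]:
    # Element-wise average over ranks, as DDP does for shared parameters' gradients.
    world_size = len(per_rank_grads)
    return [sum(vals) / world_size for vals in zip(*per_rank_grads)]


def train_step(replicas, per_rank_grads, lr, sync):
    if sync:
        avg = all_reduce_mean(per_rank_grads)
        return [sgd_step(p, avg, lr) for p in replicas]
    # No all-reduce: each rank applies its own local gradient, so replicas drift apart.
    return [sgd_step(p, g, lr) for p, g in zip(replicas, per_rank_grads)]


# Two ranks start with identical gate weights but see different local batches.
replicas = [[1.0, 2.0], [1.0, 2.0]]
grads = [[0.5, -1.0], [1.5, 1.0]]
synced = train_step(replicas, grads, lr=0.1, sync=True)
unsynced = train_step(replicas, grads, lr=0.1, sync=False)
print(synced[0] == synced[1], unsynced[0] == unsynced[1])  # True False
```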

Does any of these cases match your situation?
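To tell these cases apart, one can collect each rank's flattened gate weight in one place (for example with `torch.distributed.all_gather_object` after converting the tensor to a CPU list) and measure how far the copies diverge. The check below is a hypothetical helper operating on those gathered lists; a size mismatch across ranks hints at sharding rather than a missed all-reduce.

```python
from typing import Sequence


def max_abs_divergence(per_rank_params: Sequence[Sequence[float]]) -> float:
    # Largest element-wise spread (max - min) of one flattened parameter across ranks.
    sizes = {len(p) for p in per_rank_params}
    if len(sizes) > 1:
        # Different local sizes usually mean the parameter is sharded, not replicated.
        raise ValueError(f"ranks hold different sizes {sorted(sizes)}")
    return max((max(v) - min(v) for v in zip(*per_rank_params)), default=0.0)


def replicas_identical(per_rank_params, atol: float = 0.0) -> bool:
    return max_abs_divergence(per_rank_params) <= atol


print(replicas_identical([[1.0, 2.0], [1.0, 2.0]]))  # True
print(max_abs_divergence([[1.0, 2.0], [1.0, 2.5]]))  # 0.5
```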

ghostplant, Apr 28 '25 03:04