
Why is the ep mesh derived from a factoring of the dp mesh, instead of its own dimension?

Open man2machine opened this issue 2 months ago • 5 comments

I see that the data parallel shard dimension is factored into two dimensions, dp_shard_mod_ep and dp_shard_in_ep.

The experts use the dp_shard_mod_ep submesh for FSDP, while the rest of the blocks use the regular dp_shard_cp submesh. Why can't the experts use FSDP on the regular dp mesh? The reason for this is unclear after reading the code. If only expert parallelism is used without data parallelism, or if the data parallel size is smaller than the expert parallel size, then the dp_shard_mod_ep dimension size would be 0, which doesn't make sense.

Furthermore, the ep submesh is not a bona fide dimension of its own, but rather a combination of dp_shard_in_ep, cp and tp. Why can't ep be its own dimension? Currently ep is a strange factored submesh of dp_shard instead of being its own dimension, and I don't understand why.

I understand that various mesh dimensions are combined into dp_shard_cp to collapse them into a single 1D mesh, since FSDP accepts a 1D mesh and HSDP a 2D mesh.
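
If I'm reading the code right, the factoring arithmetic is roughly the following (a sketch only, not the actual torchtitan code, assuming ep spans dp_shard_in_ep, cp and tp as described above):

    # Rough sketch of the factoring as I understand it (not the actual torchtitan
    # code). Assumes ep is built from dp_shard_in_ep x cp x tp.
    def factor_dp_shard(dp_shard: int, cp: int, tp: int, ep: int):
        dp_shard_in_ep = ep // (cp * tp)              # part of dp_shard borrowed by EP
        dp_shard_mod_ep = dp_shard // dp_shard_in_ep  # part left for expert FSDP only
        assert dp_shard_in_ep * cp * tp == ep
        assert dp_shard_mod_ep * dp_shard_in_ep == dp_shard
        return dp_shard_mod_ep, dp_shard_in_ep

    factor_dp_shard(dp_shard=8, cp=1, tp=2, ep=8)    # -> (2, 4)
    # factor_dp_shard(dp_shard=2, cp=1, tp=1, ep=8)  # dp_shard_mod_ep would be 0 and
    #                                                # the assert fails -- the case above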

But why can't the mesh dims be for example:

(assuming cp = 1, tp = 1, etp = 1)

  • world mesh: ['pp', 'dp_replicate', 'dp_shard', 'ep', 'cp', 'tp']
  • dp_shard mesh: ['dp_shard'] (not a flattening of ['dp_shard_in_ep', 'dp_shard_mod_ep'])
  • ep mesh: ['ep'] (not 'dp_shard_in_ep')
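
Concretely, something like this is what I have in mind (purely illustrative, using the public init_device_mesh API; it assumes 16 ranks and the dim names are just my proposal):

    # Purely illustrative: the layout I have in mind, with ep as its own dim.
    # Assumes a world size of 1 * 1 * 4 * 4 * 1 * 1 = 16 ranks.
    from torch.distributed.device_mesh import init_device_mesh

    pp, dp_replicate, dp_shard, ep, cp, tp = 1, 1, 4, 4, 1, 1
    world_mesh = init_device_mesh(
        "cuda",
        (pp, dp_replicate, dp_shard, ep, cp, tp),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard", "ep", "cp", "tp"),
    )
    dp_shard_mesh = world_mesh["dp_shard"]  # used directly for FSDP
    ep_mesh = world_mesh["ep"]              # its own dim, not a factor of dp_shard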

Sorry for all the questions; I'm just pretty confused as to what's going on. The most important question is: why does dp_shard need to be factored into two dimensions? I also think the ._flatten() function should be exposed publicly if so many places use it.

man2machine avatar Nov 01 '25 02:11 man2machine

The confusion is legit, lol.

The reason we wrote the code this way was limitations in DeviceMesh's capabilities. We only had _flatten, so we had to create "atomic" meshes and build the top-level meshes up from them. Besides, to keep the complexity from exploding, we made assumptions that satisfy common use cases, including "ep needs to use part of dp_shard and all of cp, and maybe tp", "dp_shard and cp together have to be larger than ep", etc.
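
Roughly, the old pattern looked something like this (a simplified sketch with made-up sizes, not the exact torchtitan code; _flatten is a private API):

    # Simplified sketch of the "atomic dims + _flatten" pattern (not the exact
    # torchtitan code; _flatten is private). Assumes 16 ranks.
    from torch.distributed.device_mesh import init_device_mesh

    pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp = 1, 1, 2, 4, 2, 1
    world_mesh = init_device_mesh(
        "cuda",
        (pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard_mod_ep",
                        "dp_shard_in_ep", "cp", "tp"),
    )
    # FSDP mesh for the non-expert params: everything that shards the batch
    dp_shard_cp_mesh = world_mesh["dp_shard_mod_ep", "dp_shard_in_ep", "cp"]._flatten("dp_shard_cp")
    # FSDP mesh for the expert params: only the part of dp_shard not reused by EP
    dp_shard_mod_ep_mesh = world_mesh["dp_shard_mod_ep"]
    # EP mesh: the borrowed part of dp_shard together with cp (and maybe tp)
    ep_mesh = world_mesh["dp_shard_in_ep", "cp"]._flatten("ep")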

We understood the limitation, so we have reinvented DeviceMesh. The new integration code is WIP: https://github.com/pytorch/torchtitan/pull/1660. See if you like it.

tianyu-l avatar Nov 01 '25 18:11 tianyu-l

I took a look at the new code for parallel dims, as well as the _flatten and _unflatten functions in nightly. As I understand it, mesh._flatten() is basically a named version of the tensor.flatten() that you have on Tensor, and mesh._unflatten(i, dims, names) is like a named tensor[i].view(*shape).
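
In tensor terms, the analogy I have in mind is something like this (just the analogy, not the DeviceMesh API itself):

    # Just the tensor analogy, not the DeviceMesh API itself.
    import torch

    ranks = torch.arange(16).view(4, 4)  # a 2D "mesh" of ranks, dims ("dp", "tp")
    ranks.flatten()                      # ~ mesh._flatten(): merge dims into one
    ranks.view(2, 2, 4)                  # ~ mesh._unflatten(0, [2, 2], names):
                                         #   split dim 0 into two named dims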

It definitely is cleaner code, but there are some important questions I have:

  • If I'm using PyTorch stable, can I just create a new DeviceMesh instance, instead of using unflatten, and get a similar effect with no issues? Or will this mess things up in terms of process group initialization?
  • This still doesn't explain why the expert parallel dim needs to be a factored FSDP dim.

Please correct me if my understanding is wrong:

If, for example, pp=1, dp_replicate=1, cp=1, tp=1, dp_shard=16 and ep=16 (with 16 experts), then in torchtitan's setup expert parallelism and FSDP are performed on the same axis. So each GPU gets local_batch samples and holds one sharded expert, with a global batch size of 16 * local_batch. If ep were not factored from dp_shard and were its own dimension, and for example dp_shard=4 and ep=4, then each GPU would get local_batch samples and hold 4 sharded experts, with a global batch size of 4 * local_batch.
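
Spelling out the arithmetic behind those numbers (same assumptions as above, with an arbitrary local_batch):

    # Back-of-envelope for the two layouts above (16 GPUs, 16 experts).
    n_gpus, n_experts, local_batch = 16, 16, 4   # local_batch is arbitrary here

    # (a) torchtitan-style: EP factored out of the FSDP axis, dp_shard=16, ep=16
    global_batch_a = 16 * local_batch    # every GPU contributes its own samples
    experts_per_gpu_a = n_experts // 16  # 1 (sharded) expert per GPU

    # (b) EP as its own dim: dp_shard=4, ep=4 (4 * 4 = 16 GPUs)
    global_batch_b = 4 * local_batch     # only the dp_shard axis adds samples
    experts_per_gpu_b = n_experts // 4   # 4 experts per EP rank,
                                         # still FSDP-sharded over dp_shard=4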

With this example it is a bit clearer why EP may be better off being part of the FSDP axis: in the separate-axis case the global batch size is smaller, the parameter count per GPU is greater, and the communication latency is, I assume, probably not reduced by enough to warrant those trade-offs.

Therefore there are technically many other combinations of overlapping, where for example EP is performed on the same axis as dp_replicate. The only real restriction is that pp must be its own axis since it removes parts of the model, and I assume cp must also be its own axis. Technically tp could be factored from FSDP as well (and ep is a type of tp anyway). Please let me know if this understanding is correct.

As a side note, how users map the parallel dimensions to the exact node/device topology needs to be more explicit. This is a problem in both the old and new parallel dims code. For example, consider a setup with 4 machines, each with 4 GPUs, and dp=4, pp=4. How would users ensure that the pp=4 dimension is mapped to the intra-node axis rather than the inter-node axis? I get that torchrun initializes local ranks in a certain order, and that as long as you specify the mesh dims in an order that aligns with what torchrun sets you can do what you want, but it would definitely help for this to be explicit in the docs rather than non-obvious.

man2machine avatar Nov 03 '25 08:11 man2machine

cc @fegin @fduwjj on user feedback.

tianyu-l avatar Nov 03 '25 08:11 tianyu-l

@man2machine

DeviceMesh is not designed to decide how researchers/users parallelize a model. Instead, researchers/users decide how to parallelize the model and use DeviceMesh to simplify the connectivity representation in the code. With this idea in mind, let's go through each question.

If I'm using PyTorch stable, can I just create a new DeviceMesh instance, instead of using unflatten, and get a similar effect with no issues? Or will this mess things up in terms of process group initialization?

Yes, but the code can be error-prone and may have duplicated PG creation. Let's say you have a parallelism combination for certain layers, [dp, ep, tp]. If you don't want to use unflatten/flatten/slicing and opt to create a DeviceMesh each time, you may end up with duplicated PG creation if you directly create the whole [dp, ep, tp], because dp may already have been created somewhere else. If you decide to create only [ep], then you need to manually figure out the group ranks that participate in ep for each rank.

So the answer is yes, but if you are always going to create DeviceMesh manually like that, you may want to just use process groups directly.
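
For example, without DeviceMesh the "manual" version is roughly this (a sketch, assuming the ep groups are contiguous blocks of ranks):

    # Sketch of doing it manually with raw process groups, assuming ep groups
    # are contiguous blocks of ranks.
    import torch.distributed as dist

    dist.init_process_group("nccl")
    world_size, rank, ep = dist.get_world_size(), dist.get_rank(), 4

    ep_group = None
    for start in range(0, world_size, ep):
        ranks = list(range(start, start + ep))
        # every rank must call new_group for every subgroup, in the same order
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            ep_group = group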

This still doesn't explain why the expert parallel dim needs to be a factored FSDP dim.

I don't quite understand this question. Is it about how to write EP with DeviceMesh, or about why we designed EP this way? These are two different questions. For the former, we already know how to do EP, but we need DeviceMesh to better support it. The latter is orthogonal to DeviceMesh; it's a question of what the best EP/DP design is.

As a side note, how users map the parallel dimensions to the exact node/device topology needs to be more explicit. This is a problem in both the old and new parallel dims code. For example, consider a setup with 4 machines, each with 4 GPUs, and dp=4, pp=4. How would users ensure that the pp=4 dimension is mapped to the intra-node axis rather than the inter-node axis? I get that torchrun initializes local ranks in a certain order, and that as long as you specify the mesh dims in an order that aligns with what torchrun sets you can do what you want, but it would definitely help for this to be explicit in the docs rather than non-obvious.

This is related to the first question. If you always use DeviceMesh to create a new dim or a new mesh, then you will need to figure out the rank mapping yourself. But if you start from the world mesh, which contains all the ranks (world = [0, 1, 2, ..., 15] in your case), then you can just use world.unflatten([4, 4], ["dp", "tp"]). This will result in a mesh like

[[ 0,  1,  2,  3],
 [ 4,  5,  6,  7],
 [ 8,  9, 10, 11],
 [12, 13, 14, 15]]

And you know that tp works on the intra-node dim and dp works on the inter-node dim.
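
With the current public API, the same layout can be written like this (illustrative):

    # Illustrative: same 4x4 layout via the public API. init_device_mesh lays
    # ranks out row-major, so the last dim ("tp") ends up intra-node when
    # torchrun assigns ranks node by node.
    from torch.distributed.device_mesh import init_device_mesh

    mesh_2d = init_device_mesh("cuda", (4, 4), mesh_dim_names=("dp", "tp"))
    tp_mesh = mesh_2d["tp"]  # intra-node groups, e.g. {0, 1, 2, 3}
    dp_mesh = mesh_2d["dp"]  # inter-node groups, e.g. {0, 4, 8, 12}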

fegin avatar Nov 03 '25 19:11 fegin

@fegin

Regarding creating a new DeviceMesh each time instead of using mesh._unflatten(), which is not in PyTorch stable yet: I saw in the code that mesh._get_or_create_default_group(), which is called during initialization, does a torch.distributed.is_initialized() check. So wouldn't it be fine to just create a new DeviceMesh each time? I would prefer unflatten, but it's not stable yet. As long as you call torch.distributed.init_process_group() before DeviceMesh initialization, GroupMember.WORLD should always be set and thus torch.distributed.is_initialized() will return True. I don't see the problem here. If you could explain, that would be great.
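
Concretely, what I have in mind is something like this (a sketch; I may be missing a subtlety, and it assumes the world size is a multiple of 4):

    # Sketch of what I mean: construct a fresh DeviceMesh directly, relying on
    # init_process_group having been called first. Assumes world_size % 4 == 0.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import DeviceMesh

    dist.init_process_group("nccl")  # GroupMember.WORLD is set from here on
    mesh = DeviceMesh(
        "cuda",
        torch.arange(dist.get_world_size()).view(4, -1),
        mesh_dim_names=("dp", "tp"),
    )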

The other questions are not about how to do EP but about why you designed EP in that particular way. It's unrelated to the exact DeviceMesh implementation/usage, and it was my original question about why the EP mesh is derived from the DP mesh. Why is EP performed on the same (factored) axis as DP? Why not TP also on the same axis as DP? I just didn't understand why these specific decisions were chosen for the best EP/DP/TP design. As I understand PP and CP are sort of their own thing and can't really be overlapped with the other parallelism styles. But EP/DP/TP can be overlapped or kept separate in many different possible configurations.

man2machine avatar Nov 04 '25 02:11 man2machine

Why is EP performed on the same (factored) axis as DP? Why not TP also on the same axis as DP? I just didn't understand why these specific decisions were chosen for the best EP/DP/TP design. As I understand PP and CP are sort of their own thing and can't really be overlapped with the other parallelism styles. But EP/DP/TP can be overlapped or kept separate in many different possible configurations.

I'm also curious about this and not quite clear on the answer from this thread... cc @tianyu-l could you elaborate?

danielvegamyhre avatar Dec 02 '25 00:12 danielvegamyhre

Why is EP performed on the same (factored) axis as DP? Why not TP also on the same axis as DP? I just didn't understand why these specific decisions were chosen for the best EP/DP/TP design.

Let's separate discussion of MoE layers and dense layers.

  • A dense layer can't use EP; it can use DP / TP / CP / PP or any combination of them.
  • An MoE layer can use DP / TP / CP / PP / EP or any combination of them.

So "Why not TP also on the same axis as DP?" for dense layer itself is not a legit question. If you are asking: can MoE TP reuse dense layer's DP, technically this is doable, but I guess there's not much benefit to use TP for MoE but no-TP / less-TP for dense for mainstream model arch. To fully verify this one could do some back-of-envelope estimation of all the computation/communication tradeoffs and run some experiments?

tianyu-l avatar Dec 02 '25 01:12 tianyu-l