
Pipeline Parallelism (Supported? How to?)

Open casper-hansen opened this issue 1 year ago • 4 comments

🚀 Feature Request

Supporting TP and SP seems quite easy to do with the `replication` parameter:

replication = tp * sp
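
As a rough sketch (not from the issue itself), this is how the combined factor would be passed through; the dataset paths and batch size below are placeholders:

```python
from streaming import StreamingDataset

tp = 4  # tensor parallel degree
sp = 2  # sequence parallel degree

# Every block of tp * sp consecutive global ranks sees identical samples.
dataset = StreamingDataset(
    remote='s3://my-bucket/my-dataset',  # placeholder remote path
    local='/tmp/my-dataset-cache',       # placeholder local cache
    batch_size=8,
    replication=tp * sp,
)
```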

I have tried various ways to enable PP without success. I tried adding pp into the equation when computing replication and num_canonical_nodes, but training does not behave normally: I get an unexpectedly high loss.

Motivation

I want to use the mosaicml streaming library with 4D parallelism. Specifically, I rely on TorchTitan as my training tool and have simply swapped in the mosaicml streaming library by modifying the StreamingTextDataset implementation from LLM Foundry.

casper-hansen avatar Nov 14 '24 09:11 casper-hansen

We can look into this in more detail. In the meantime, have you tried using mosaicml/composer for training? Are there specific features you are relying on in TorchTitan?

ethantang-db avatar Nov 15 '24 19:11 ethantang-db

I would really appreciate it if you could look into it! TorchTitan uses torch.distributed.pipelining, most of which is only available from PyTorch 2.5.0 or in nightly builds.

There are many key features like FSDP2, 4D parallelism, FP8, and torch.compile that make Llama models scale well in pretraining. You also get full control over the training loop, which is desirable if you want to experiment.

casper-hansen avatar Nov 15 '24 20:11 casper-hansen

@casper-hansen So StreamingDataset's replication argument assumes that the ranks that have replicated samples are in contiguous blocks of global rank indices. Concretely, suppose on 16 GPUs I have a replication factor of 2. Then StreamingDataset will replicate the same samples on GPU ranks 0 and 1, 2 and 3, 4 and 5, and so on. In the 4D parallelism case, you likely have ranks that are not contiguous but still want to replicate samples over them (using the previous example, you may want GPU ranks 0, 1, 8, and 9 to see the same samples).
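
A plain-Python illustration of the difference (the 4D rank layout below assumes a (pp, dp, tp) ordering with tp fastest; your mesh may order dimensions differently):

```python
# With replication=2 on 16 ranks, StreamingDataset groups *contiguous* ranks:
world_size, replication = 16, 2
contiguous_groups = [
    list(range(start, start + replication))
    for start in range(0, world_size, replication)
]
# -> [[0, 1], [2, 3], [4, 5], ..., [14, 15]]

# Under 4D parallelism, the ranks that should share samples are generally
# *not* contiguous. Assuming a layout rank = pp_rank * (dp * tp)
# + dp_rank * tp + tp_rank with pp=2, dp=4, tp=2, the ranks that belong to
# the same data-parallel group across pp and tp are:
pp, dp, tp = 2, 4, 2
desired_groups = [
    [pp_rank * dp * tp + dp_rank * tp + tp_rank
     for pp_rank in range(pp) for tp_rank in range(tp)]
    for dp_rank in range(dp)
]
# -> [[0, 1, 8, 9], [2, 3, 10, 11], [4, 5, 12, 13], [6, 7, 14, 15]]
```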

We currently enable replication through the World object's replicate function (called here), which sets the correct global node and rank indices used to construct the sample partition and retrieve samples. If you want to try enabling 4D parallelism yourself, I would take a look at the replicate function here and allow it to create a new World object with the right information for your sharding and parallelism strategy.
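
For what it's worth, here is a hedged sketch of the rank remapping such a modified replicate function would need to perform; the rank layout is an assumption and should be adjusted to match your actual mesh ordering:

```python
def remap_for_4d(global_rank: int, pp: int, dp: int, tp: int, sp: int):
    """Map a global rank to (dp_rank, replica_index) under 4D parallelism.

    Assumes ranks are laid out as
    rank = ((pp_rank * dp + dp_rank) * tp + tp_rank) * sp + sp_rank,
    i.e. sp fastest, then tp, then dp, then pp. Adjust to your mesh.
    """
    sp_rank = global_rank % sp
    tp_rank = (global_rank // sp) % tp
    dp_rank = (global_rank // (sp * tp)) % dp
    pp_rank = global_rank // (sp * tp * dp)
    # All ranks sharing dp_rank should receive identical samples; the replica
    # index just orders the (pp, tp, sp) copies within that group.
    replica_index = (pp_rank * tp + tp_rank) * sp + sp_rank
    return dp_rank, replica_index
```

The modified replicate function would then build the new World using dp_rank as the effective rank and the DP degree as the effective world size, rather than grouping contiguous global rank indices.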

snarayan21 avatar Jan 04 '25 05:01 snarayan21

It would be great to integrate the new DeviceMesh abstraction from PyTorch.
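
A hedged sketch of what that could look like (this is not an existing streaming feature; the mesh sizes and dim names are examples, and it assumes torch >= 2.2 launched under torchrun):

```python
from torch.distributed.device_mesh import init_device_mesh

# Example 3D mesh; a 4D layout would add an 'sp' (or 'cp') dimension.
mesh = init_device_mesh(
    'cuda',
    mesh_shape=(2, 4, 2),
    mesh_dim_names=('pp', 'dp', 'tp'),
)

# The data-parallel coordinate determines which sample shard this rank reads;
# ranks that differ only in their pp/tp coordinates should see identical samples.
dp_rank = mesh.get_local_rank('dp')   # effective rank for sample partitioning
dp_world_size = mesh['dp'].size()     # effective world size (= dp degree)
```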

cassanof avatar Jan 13 '25 07:01 cassanof