[rfc] getting rid of seed-checkpoint for Pipeline Parallelism
Currently PP uses a 'seed checkpoint' for initialization because (1) it's nice to initialize the model the same way (same RNG) as non-PP for loss-comparison purposes, and (2) the whole model may not fit in GPU memory (or even CPU memory, when accounting for 8 copies, one per GPU).
The downside is that creating the seed checkpoint adds an extra step, one that gets slower as the model grows, which is not a good user experience.
Step 1: Make model.init_weights 'pipeline friendly'
Currently, if we call init_weights on a model chunk after meta-init and pipeline splitting, we crash: init_weights expects all the parameters of the model, but pipeline splitting deletes some of them.
A pretty simple fix is to modify Transformer.init_weights to respect the possibility that self.tok_embeddings or self.output is None (and skip initializing them if so). The layers should already be fine, since the loop will only visit layers that PP did not delete. Now our initializer runs without crashing.
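A minimal sketch of the idea, using a toy model rather than the real torchtitan Transformer (the class, field names, and init choices below are illustrative only):

```python
import torch.nn as nn

class ToyTransformer(nn.Module):
    """Toy stand-in for the real Transformer, just to illustrate the None guards."""
    def __init__(self, vocab: int = 256, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.output = nn.Linear(dim, vocab, bias=False)

    def init_weights(self):
        # Skip modules that pipeline splitting removed from this stage.
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        # The loop naturally visits only the layers this stage still owns.
        for layer in self.layers.values():
            nn.init.trunc_normal_(layer.weight, std=0.02)
        if self.output is not None:
            nn.init.normal_(self.output.weight, std=0.02)

# A middle pipeline stage after splitting might look like this:
stage_model = ToyTransformer()
stage_model.tok_embeddings = None   # owned by the first stage
stage_model.output = None           # owned by the last stage
del stage_model.layers["0"]         # layer assigned to another stage
stage_model.init_weights()          # runs without crashing
```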
So far we have only unblocked basic functionality (running CI, checking WPS and peak memory); every PP stage will use the same RNG state, so convergence may be affected.
Note: this approach does not take finer-grained splitting into account. If users wanted to put "half a transformer layer" on a pipeline stage, additional work would be needed to make initialization work.
Step 2: Fix the RNG problem
Option 1: A quick thing to try is to add a function to torch.pipelining that draws PP_ranks - 1 random integers and broadcasts them to the nonzero PP ranks, which use them to set their own seeds. Now every PP rank starts out with a different seed and none of the layers get the same initial values. At this point we should probably converge OK, even though we would still not exactly match a non-PP initialization.
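A hedged sketch of what such a helper could look like; `seed_pp_ranks`, `pp_group`, and `base_seed` are made-up names, and this is not an existing torch.pipelining API:

```python
import torch
import torch.distributed as dist

def seed_pp_ranks(pp_group: dist.ProcessGroup, base_seed: int = 0) -> None:
    """Hypothetical helper: rank 0 draws one seed per nonzero PP rank and
    broadcasts them; each nonzero rank then seeds itself with its own entry."""
    pp_rank = dist.get_rank(pp_group)
    pp_size = dist.get_world_size(pp_group)
    device = torch.device("cuda") if dist.get_backend(pp_group) == "nccl" else torch.device("cpu")
    if pp_rank == 0:
        torch.manual_seed(base_seed)
        seeds = torch.randint(0, 2**31 - 1, (pp_size - 1,), dtype=torch.int64, device=device)
    else:
        seeds = torch.empty(pp_size - 1, dtype=torch.int64, device=device)
    # Broadcast the drawn seeds from PP rank 0 to the rest of the PP group.
    dist.broadcast(seeds, src=dist.get_global_rank(pp_group, 0), group=pp_group)
    if pp_rank > 0:
        # Every PP rank ends up with a distinct seed.
        torch.manual_seed(seeds[pp_rank - 1].item())
```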
Option 2: A more advanced version would be a function on the torch.pipelining Schedule class that accepts the model's init_weights function as an argument; the schedule already has pointers to all the local stages and their model chunks, and it knows the pipeline order (looped, interleaved, V-shaped, etc.). It would sequentially initialize one chunk at a time, starting from chunk 0 on rank 0. After that, it would extract the current local RNG state and send it to the rank holding chunk 1, which would set its RNG state to match, initialize chunk 1, extract its updated RNG state, and send it on to the next rank, and so on until complete.
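A rough sketch of the sequential handoff. Option 2 as described would live on the Schedule and use point-to-point sends; this sketch simplifies it to a broadcast of the RNG state, assumes CUDA is available, and the names `init_stages_in_order`, `local_stages`, and `stage_to_rank` are hypothetical:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

def init_stages_in_order(local_stages: dict[int, nn.Module],
                         stage_to_rank: dict[int, int],
                         pp_group: dist.ProcessGroup) -> None:
    """Hypothetical sketch of Option 2: initialize chunks in stage order and hand
    the RNG state from each stage's owner to the rank that holds the next stage."""
    my_rank = dist.get_rank(pp_group)
    for stage_idx in sorted(stage_to_rank):
        owner = stage_to_rank[stage_idx]
        if my_rank == owner:
            local_stages[stage_idx].init_weights()
            state = [torch.get_rng_state(), torch.cuda.get_rng_state()]
        else:
            state = [None, None]
        # Everyone participates; only the owner contributes real data.
        dist.broadcast_object_list(state, src=dist.get_global_rank(pp_group, owner), group=pp_group)
        torch.set_rng_state(state[0])
        torch.cuda.set_rng_state(state[1])
```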
cc @H-Huang @wanchaol @tianyu-l @lessw2020
Just a random thought: do we need torch.pipelining.Schedule to do Step 2, Option 2? Can we first traverse and bookkeep the module initialization order (before pipelining is applied) and then replay that order during init_weights? There is no memory issue, because we just need to know the order without materializing the memory.
During init_weights, if a rank does not participate in the initialization of the current module (because the module was deleted by PP), it simply does nothing. After each module initialization, all ranks perform an all-gather of the RNG state. Ranks that just participated in the module initialization send their RNG state; the other ranks send None (or some tensor indicating they did not participate). This way, we can guarantee that every rank gets the correct RNG seed without coupling with PP's implementation.
> Just a random thought: do we need torch.pipelining.Schedule to do Step 2, Option 2?
The reason for this is that there are two common stage-to-rank mappings, Loop and V (illustrated in the sketch below):
- Loop (e.g. DFS): rank0: stage0, stage2; rank1: stage1, stage3
- V (e.g. ZeroBubble V): rank0: stage0, stage3; rank1: stage1, stage2
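For concreteness, a tiny illustration of the two mappings; `stages_for_rank` is a made-up helper, not a torch.pipelining API, and it assumes exactly two stages per rank:

```python
def stages_for_rank(rank: int, pp_size: int, style: str) -> list[int]:
    """Illustrative only: stage ownership under the two mappings, assuming
    exactly two stages per rank (num_stages == 2 * pp_size)."""
    if style == "loop":
        # Loop / DFS: stages are dealt out round-robin across ranks.
        return [rank, rank + pp_size]
    if style == "v":
        # V (e.g. ZeroBubble V): one stage from the front, its mirror from the back.
        return [rank, 2 * pp_size - 1 - rank]
    raise ValueError(f"unknown style: {style}")

# pp_size = 2:
#   loop -> rank0: [0, 2], rank1: [1, 3]
#   v    -> rank0: [0, 3], rank1: [1, 2]
```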
> Can we first traverse and bookkeep the module initialization order (before pipelining is applied) and then replay that order during init_weights?
I'm not even considering the ordering of modules being nontrivial; I just assume it's a straightforward ordering (iterating the layers dict). But I suppose we do not have to keep track, because we delegate: (1) we have split the whole model into chunks, and we need to initialize the chunks in sequential order; (2) within a chunk, the user-updated init_weights() needs to do the right ordering within its own chunk, but we don't care about that at runtime.
> During init_weights, if a rank does not participate in the initialization of the current module (because the module was deleted by PP), it simply does nothing. After each module initialization, all ranks perform an all-gather of the RNG state. Ranks that just participated in the module initialization send their RNG state; the other ranks send None (or some tensor indicating they did not participate). This way, we can guarantee that every rank gets the correct RNG seed without coupling with PP's implementation.
Your proposal would work too. You suggest performing the initialization in a way like this:
```python
for stage_idx in range(num_stages):
    if stage_idx in local_stages:
        local_stages[stage_idx].model.init_weights()
    perform_rng_allgather()
```
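To flesh that out, a hedged sketch of what `perform_rng_allgather` could look like (with explicit arguments that the pseudocode above omits, and handling only the CPU RNG state for brevity):

```python
import torch
import torch.distributed as dist

def perform_rng_allgather(participated: bool, pp_group: dist.ProcessGroup) -> None:
    """Hypothetical helper: after each init step, the rank that actually ran the
    initialization shares its updated RNG state; everyone else adopts it."""
    states = [None] * dist.get_world_size(pp_group)
    my_state = torch.get_rng_state() if participated else None
    dist.all_gather_object(states, my_state, group=pp_group)
    # Adopt the state from the (single) rank that actually ran the init.
    new_state = next((s for s in states if s is not None), None)
    if new_state is not None and not participated:
        torch.set_rng_state(new_state)
```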
> I'm not even considering the ordering of modules being nontrivial; I just assume it's a straightforward ordering (iterating the layers dict). But I suppose we do not have to keep track, because we delegate: (1) we have split the whole model into chunks, and we need to initialize the chunks in sequential order; (2) within a chunk, the user-updated init_weights() needs to do the right ordering within its own chunk, but we don't care about that at runtime.
But if we perform a fake initialization on the meta device, just like tracing or compilation does, we should be able to get the order.
> Your proposal would work too. You suggest performing the initialization in a way like this:
>
> ```python
> for stage_idx in range(num_stages):
>     if stage_idx in local_stages:
>         local_stages[stage_idx].model.init_weights()
>     perform_rng_allgather()
> ```
Yes, similar idea, but tracing the order first is required to make this general. It should be okay, though, if we just have layers like the current llama2/3.
Oops, forgot about this RFC. I opened another one on the PyTorch side for RNG-state management for torch.pipelining. Perhaps reassuringly, I rediscovered the same 2 options as in this proposal 😅.
https://github.com/pytorch/pytorch/issues/139304
Closing this RFC now, after landing various enhancements to DTensor RNG and updating TorchTitan's RNG configuration (#689).
1. The seed checkpoint is not deleted, but it is now strictly optional; it should only be used when someone wants to compare the same model init across different parallel configs.
2. We deemed it practically impossible to make the TP/DP parallelisms behave in a way that model initialization would match single-GPU, due to how the RNG works.
3. There was a PR for making pipeline parallelism sync RNG states to match single-GPU behavior during model init, but it was scrapped since it would only work for pure PP, not PP+(TP/DP), due to (2).
4. DTensor's RNG infra was updated to (a) stop using TensorParallelRNGTracker, since it was not correct in all cases, and (b) avoid doing a world-broadcast of the RNG seed if someone calls the manual_seed API, as TorchTitan does in (5). https://github.com/pytorch/pytorch/pull/141220 https://github.com/pytorch/pytorch/pull/141223
5. TorchTitan has been updated (#689) to configure the RNG seeds appropriately so that all the supported parallelisms (and combinations of PP/TP/DP, etc.) correctly initialize the model. Namely, each PP stage uses a different RNG seed (see the sketch below), and DTensor's RNG is configured within that PP stage.
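A minimal illustration of the per-stage seeding idea in point 5; this is not the actual TorchTitan code, and the function name and the simple additive offset are assumptions:

```python
import torch

def set_model_init_seed(base_seed: int, pp_stage_index: int) -> None:
    """Illustrative only: give each pipeline stage its own RNG seed before the
    SPMD parallelisms (TP/DP, DTensor) initialize the parameters on that stage."""
    stage_seed = base_seed + pp_stage_index   # distinct seed per PP stage (assumed offset scheme)
    torch.manual_seed(stage_seed)             # seeds the CPU and all CUDA generators
```

The intent, per points 4 and 5 above, is that DTensor then takes its RNG from this per-stage seed rather than from a world-broadcast one.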