[torchtitan][replicate] experimenting new replicate integration with torchtitan
Summary: For this experiment integrating the new replicate function into torchtitan, I used https://github.com/pytorch/pytorch/pull/162021, which has not landed. However, since that PR is about making replicate more efficient rather than changing replicate's core code, https://github.com/pytorch/pytorch/pull/160135, which has landed, should be sufficient. https://github.com/pytorch/pytorch/pull/160133 is the last PR that touched replicate_with_fsdp.py and its replicate API.
To enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include a dp_shard dimension of size 1 whenever dp_replicate > 1, and created the device mesh that gets passed down into apply_ddp.
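A minimal sketch of that mesh construction (names and the `apply_ddp` signature are simplified, not the exact torchtitan code; `model` is assumed to be in scope):

```python
# Sketch: build the 2D ["dp_replicate", "dp_shard"] mesh that the new
# replicate (a specialized HSDP) expects, with dp_shard pinned to 1.
from torch.distributed.device_mesh import init_device_mesh

dp_replicate = 8  # from the job config

if dp_replicate > 1:
    # Every rank holds a full replica: the shard dim has size 1.
    dp_mesh = init_device_mesh(
        "cuda",
        (dp_replicate, 1),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )
    apply_ddp(model, dp_mesh)  # hypothetical: hands the mesh to replicate()
```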
The numeric tests for tp + replicate and pp + replicate can be seen below. To verify that they worked, I also compared them against HSDP with mesh shape (n, 1), i.e. (replicate, shard).
https://fburl.com/mlhub/5k9v43w3
Test Case
- CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh (set replicate to 8)
Expected output of this experiment should be something like:

```
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)
```
Stack from ghstack (oldest at bottom):
- -> #1714
In this case, mixed precision of `replicate_with_fsdp` should be handled by `fully_shard` instead of AMP. This means that we need to modify `torchtitan/distributed/utils.py/maybe_enable_amp()` to accommodate `replicate_with_fsdp`.
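A hedged sketch of what that change could look like (the signature is approximated and the `dp_replicate_enabled` condition is hypothetical, not the actual torchtitan config):

```python
import contextlib

import torch

def maybe_enable_amp(parallel_dims, mixed_precision_param, device_type):
    # Hypothetical branch: when the new replicate is built on fully_shard,
    # its MixedPrecisionPolicy already handles param/grad dtypes, so
    # wrapping the forward in autocast would cast twice.
    if parallel_dims.dp_replicate_enabled:  # i.e. replicate_with_fsdp is in use
        return contextlib.nullcontext()
    return torch.autocast(device_type, dtype=mixed_precision_param)
```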
By the way, `DistributedDataParallel` has experimental support for native mixed precision, similar to FSDP2's `MixedPrecisionPolicy`. This means that perhaps torchtitan can remove `torchtitan/distributed/utils.py/maybe_enable_amp()` completely. See DDP native mixed precision #92882.
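For reference, my understanding of the prototype API from #92882 is roughly the following (the underscore-prefixed config class signals it is experimental and may have changed since):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel.distributed import _MixedPrecision  # experimental

# DDP itself keeps a low-precision copy for compute and reduces gradients
# in the configured dtype, so no external autocast context is needed.
ddp_model = DDP(
    model,
    mixed_precision=_MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
)
```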
cc @weifengpy @tianyu-l
@EquationWalker
> In this case, mixed precision of `replicate_with_fsdp` should be handled by `fully_shard` instead of AMP. This means that we need to modify `torchtitan/distributed/utils.py/maybe_enable_amp()` to accommodate `replicate_with_fsdp`.
Great point! @anshul-si Let's accommodate.
> In this case, mixed precision of `replicate_with_fsdp` should be handled by `fully_shard` instead of AMP. This means that we need to modify `torchtitan/distributed/utils.py/maybe_enable_amp()` to accommodate `replicate_with_fsdp`. By the way, `DistributedDataParallel` has experimentally supported native mixed precision, similar to `MixedPrecisionPolicy` of FSDP2.
@EquationWalker good catch!
My requested changes are mainly about the 2D mesh. We should target a 1D mesh for landing; it's a user contract in a public-facing API.
I think the use of a 2D mesh has something to do with the `FSDPParamGroup` user contract. When passed a 2D mesh, `FSDPParamGroup` treats it as HSDP: it shards parameters along the second dimension and replicates them along the first. If you pass a 1D mesh, `FSDPParamGroup` will shard parameters on that mesh instead of replicating them.
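To illustrate that contract (a simplified sketch over 8 ranks; `model_a` and `model_b` are assumed to be already-constructed modules):

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# 1D mesh: parameters are sharded across all 8 ranks (plain FSDP).
mesh_1d = init_device_mesh("cuda", (8,))
fully_shard(model_a, mesh=mesh_1d)

# 2D mesh: treated as HSDP -- replicate along dim 0, shard along dim 1.
# With a shard dim of size 1, this degenerates to pure replication,
# which is what replicate_with_fsdp builds on.
mesh_2d = init_device_mesh("cuda", (8, 1), mesh_dim_names=("replicate", "shard"))
fully_shard(model_b, mesh=mesh_2d)
```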
pytorch/issues#159013 mentioned adding a reshape operation to the mesh, but it seems that PyTorch has not implemented it yet.
One solution would be to recreate a new 2D mesh of shape (N, 1) from the 1D mesh's shape information (N,) inside apply_ddp, but this would create new communication groups, and I'm not sure whether that is an expensive operation.
If `replicate_with_fsdp` could internally convert any 2D mesh of shape (N, M) to (N*M, 1), and any 1D mesh of shape (N,) to (N, 1), perhaps we could use `fully_shard` and `replicate_with_fsdp` in combination, as sketched below.
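A hypothetical helper for that conversion (not an existing PyTorch API; it assumes the input mesh spans the whole world, and, as noted above, rebuilding the mesh creates new process groups):

```python
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

def _normalize_replicate_mesh(device_type: str, mesh: DeviceMesh) -> DeviceMesh:
    # (N,) -> (N, 1) and (N, M) -> (N*M, 1): dim 0 replicates, dim 1 shards,
    # so FSDPParamGroup ends up replicating across all ranks.
    n_ranks = mesh.size()  # total ranks: N for a 1D input, N*M for a 2D input
    return init_device_mesh(
        device_type,
        (n_ranks, 1),
        mesh_dim_names=("replicate", "shard"),
    )
```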
cc @anshul-si @tianyu-l
Taking a 1D mesh as the input has been on the roadmap, but there are two issues:
1. We need DeviceMesh unflatten support; the PR is being reviewed but has not landed yet. But this is just a WIP.
2. IIRC, there is a concern about the input mesh being 1D but `named_parameters()` and `state_dict()` will actually return a 2D DeviceMesh, which can cause inconsistent views for users.
@anshul-si @mori360 Please correct me if I'm wrong for the second issue.
> 2. iirc, there is a concern about the input mesh being 1D but `named_parameters()` and `state_dict()` will actually return 2D DeviceMesh, which can cause inconsistent views for users.
I was thinking `named_parameters()` and `state_dict()` would return 1D.
> We need DeviceMesh unflatten support; the PR is being reviewed but has not landed yet.
Which PR? I don't mind creating a new 2D mesh inside replicate to unblock for now, but the user contract has to be right on day 1:
- input: a 1D mesh
- output: `model.parameters()` and `model.state_dict()` being 1D
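In other words, something like this hypothetical check should hold (the `replicate` call signature is assumed from the new API, not confirmed):

```python
from torch.distributed.device_mesh import init_device_mesh

mesh_1d = init_device_mesh("cuda", (8,))
replicate(model, device_mesh=mesh_1d)  # assumed replicate_with_fsdp entry point

# The 2D mesh used internally must not leak back out to the user.
for name, param in model.named_parameters():
    assert param.device_mesh.ndim == 1, f"{name} exposes a 2D mesh"
```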