
[torchtitan][replicate] experimenting new replicate integration with torchtitan

Open anshul-si opened this issue 3 months ago • 6 comments

Summary: For this experiment integrating the new replicate function into torchtitan, I used https://github.com/pytorch/pytorch/pull/162021, which has not landed yet. However, since that PR is about making replicate more efficient rather than changing replicate's core code, https://github.com/pytorch/pytorch/pull/160135, which has landed, should be sufficient. https://github.com/pytorch/pytorch/pull/160133 is the last PR that touched replicate_with_fsdp.py and its replicate API.

To enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include a dp_shard dim of size 1 only when dp_replicate > 1, and created a device mesh that I pass down to apply_ddp.
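For illustration, a minimal sketch of that change (hypothetical names; the actual torchtitan parallelism code differs): a dp_shard dimension of size 1 is added only when dp_replicate > 1, and the resulting 2D mesh would be handed to apply_ddp.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

# Sketch only: run under torchrun with dp_replicate ranks (e.g. 8).
dp_replicate, dp_shard = 8, 1
model = nn.Linear(256, 256)  # stand-in for the llama3 debug model

if dp_replicate > 1:
    dp_mesh = init_device_mesh(
        "cuda",
        (dp_replicate, dp_shard),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )
    # apply_ddp(model, dp_mesh)  # would call the new replicate with this 2D mesh
```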

The numerics tests for TP + replicate and PP + replicate can be seen below. To verify that they work, I also compared them against HSDP with an (n, 1) (replicate, shard) mesh.

[image: numerics comparison of TP/PP + replicate vs. HSDP (n, 1)]

https://fburl.com/mlhub/5k9v43w3

Test Case

  1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh (set replicate to 8)

Expected output of this experiment should be something like:

[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)

Stack from ghstack (oldest at bottom):

  • -> #1714

anshul-si avatar Sep 15 '25 23:09 anshul-si

In this case, mixed precision for replicate_with_fsdp should be handled by fully_shard instead of AMP. This means we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp. By the way, DistributedDataParallel has experimental support for native mixed precision, similar to FSDP2's MixedPrecisionPolicy, so perhaps torchtitan could remove torchtitan/distributed/utils.py/maybe_enable_amp() entirely. See DDP native mixed precision #92882. cc @weifengpy @tianyu-l
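For reference, a minimal sketch of what handing mixed precision to fully_shard (rather than AMP) could look like, using FSDP2's MixedPrecisionPolicy; the dtypes here are just illustrative:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Sketch: let fully_shard own mixed precision instead of autocast/AMP.
# Run under torchrun with an initialized process group.
model = nn.Linear(256, 256)  # stand-in for the real transformer block
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # compute/communication dtype for params
    reduce_dtype=torch.float32,  # keep gradient reductions in fp32
)
fully_shard(model, mp_policy=mp_policy)
# With this in place, maybe_enable_amp() would be skipped for the
# replicate_with_fsdp path (or removed entirely, per the suggestion above).
```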

EquationWalker avatar Sep 24 '25 14:09 EquationWalker

@EquationWalker

In this case, mixed precision for replicate_with_fsdp should be handled by fully_shard instead of AMP. This means we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp.

Great point! @anshul-si Let's accommodate.

tianyu-l avatar Sep 24 '25 18:09 tianyu-l

In this case, mixed precision for replicate_with_fsdp should be handled by fully_shard instead of AMP. This means we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp. By the way, DistributedDataParallel has experimental support for native mixed precision, similar to FSDP2's MixedPrecisionPolicy.

@EquationWalker good catch!

weifengpy avatar Sep 24 '25 19:09 weifengpy

My requested changes are mainly about the 2D mesh. We should target a 1D mesh for landing; it's a user contract in a public-facing API.

I think the use of a 2D mesh has to do with the FSDPParamGroup user contract. When you pass a 2D mesh, FSDPParamGroup treats it as HSDP: it shards parameters along the second dimension and replicates them along the first. If you pass a 1D mesh, FSDPParamGroup will shard parameters on that mesh instead of replicating them.
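An illustration of that contract (a sketch assuming FSDP2's current fully_shard behavior; run under torchrun with 8 GPUs):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# 2D mesh -> treated as HSDP: replicate along dim 0, shard along dim 1.
hsdp_model = nn.Linear(64, 64)
hsdp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
fully_shard(hsdp_model, mesh=hsdp_mesh)

# 1D mesh -> plain FSDP: parameters are sharded across all 8 ranks, no replication.
fsdp_model = nn.Linear(64, 64)
fsdp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("shard",))
fully_shard(fsdp_model, mesh=fsdp_mesh)
```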

pytorch/issues#159013 mentioned adding a reshape operation to DeviceMesh, but it seems PyTorch has not implemented it yet.

One solution would be to recreate a 2D mesh of shape (N, 1) from the 1D mesh's shape information (N,) inside apply_ddp, but this would create new communication groups, and I'm not sure how expensive that is.

If replicate_with_fsdp could internally convert any 2D mesh of shape (N, M) to (N*M, 1), and any 1D mesh of shape (N,) to (N, 1), perhaps we could use fully_shard and replicate_with_fsdp in combination, as sketched below.
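A minimal sketch of that conversion, assuming the caller accepts the cost of creating new process groups (the helper name is hypothetical, not part of replicate_with_fsdp):

```python
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

def normalize_replicate_mesh(mesh: DeviceMesh) -> DeviceMesh:
    """Map (N,) -> (N, 1) and (N, M) -> (N*M, 1) so dim 0 replicates and dim 1 shards."""
    if mesh.ndim == 1:
        n = mesh.size()
    elif mesh.ndim == 2:
        n = mesh.size(0) * mesh.size(1)
    else:
        raise ValueError(f"expected a 1D or 2D mesh, got {mesh.ndim}D")
    # NOTE: init_device_mesh creates new communication groups, which is the
    # cost concern mentioned above.
    return init_device_mesh(
        mesh.device_type,
        (n, 1),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )
```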

cc @anshul-si @tianyu-l

EquationWalker avatar Sep 25 '25 02:09 EquationWalker

Taking a 1D mesh as the input has been on the roadmap, but there are two issues.

  1. We need DeviceMesh unflatten support, for which the PR is being reviewed but has not landed yet. But this is just a WIP.

  2. iirc, there is a concern that if the input mesh is 1D, named_parameters() and state_dict() would actually return parameters on a 2D DeviceMesh, which can give users inconsistent views (illustrated in the sketch below).

@anshul-si @mori360 Please correct me if I'm wrong for the second issue.
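To make the second concern concrete, here is a hypothetical check (not existing torchtitan code) of what a user might observe: even though they passed a 1D mesh, the parameters could come back as DTensors on a 2D mesh.

```python
import torch.distributed.tensor as dtensor

def report_param_meshes(model):
    # After applying replicate_with_fsdp with a 1D input mesh, print the mesh
    # each parameter actually lives on. If these report a 2D mesh, the user's
    # 1D-mesh view and the state_dict/checkpoint view diverge.
    for name, p in model.named_parameters():
        if isinstance(p, dtensor.DTensor):
            print(f"{name}: mesh ndim={p.device_mesh.ndim}, placements={p.placements}")
```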

fegin avatar Sep 25 '25 17:09 fegin

2. iirc, there is a concern that if the input mesh is 1D, named_parameters() and state_dict() would actually return parameters on a 2D DeviceMesh, which can give users inconsistent views.

I was thinking named_parameters() and state_dict() should return 1D.

We need DeviceMesh unflatten support, for which the PR is being reviewed but has not landed yet.

Which PR? I don't mind creating a new 2D mesh inside replicate to unblock for now, but the user contract has to be right on day 1:

  • input: 1D mesh
  • output: model.parameters() and model.state_dict() being 1D

weifengpy avatar Sep 25 '25 20:09 weifengpy