[RFC] Enable HSDP
Stack from ghstack (oldest at bottom):
- -> #518
This PR enables HSDP.
Discussions
1. How does the trainer get the DP mesh?
Right now, we flatten ["dp_replicate", "dp_shard"] into a single flattened dimension. Because DeviceMesh currently does not support slicing out a flattened dimension, we need to either a) flatten again, or b) bookkeep the flattened result. Why do we initialize all the device meshes at the beginning? It is a good strategy to avoid possible deadlocks and timeouts, and it makes debugging easier for users -- initializing a brand-new device mesh requires all ranks to participate.
What is the alternative solution? If DeviceMesh could support slicing a flattened dimension, we could just slice it out of the world mesh. Moreover, we may have to support slicing a flattened dimension anyway: with HSDP + CP, we need to pass the DeviceMesh ["dp_replicate", "dp_shard_cp"] to fully_shard, where dp_shard_cp is a flattened dimension.
However, if DeviceMesh supports slicing a flattened dimension, what will the name be? Currently, DeviceMesh implicitly concatenates the names of the dimensions that form the flattened dimension. Is this too implicit? We also need to discuss this issue.
Conclusion: use named flatten + slicing, and wait for the corresponding PRs from PyTorch.
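For concreteness, here is a minimal sketch (not the TorchTitan code) of the named flatten + slicing flow. It assumes the private `DeviceMesh._flatten` API and multi-dimension slicing available in recent PyTorch builds:

```python
# Hypothetical sketch of "named flatten + slicing"; not the actual TorchTitan code.
from torch.distributed.device_mesh import init_device_mesh

# e.g. 2-way replication x 4-way sharding on 8 GPUs
world_mesh = init_device_mesh(
    "cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard")
)

# Slice out the 2D DP sub-mesh and flatten it under an explicit name, so the
# trainer can later look up a single "dp" dimension (e.g. for the dataloader)
# instead of bookkeeping the flattened result itself.
dp_mesh_2d = world_mesh["dp_replicate", "dp_shard"]
dp_mesh = dp_mesh_2d._flatten(mesh_dim_name="dp")

# For HSDP, fully_shard takes the 2D (replicate, shard) sub-mesh:
# from torch.distributed._composable.fsdp import fully_shard
# fully_shard(model, mesh=dp_mesh_2d)
```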
2. How does TorchTitan expose HSDP to users?
Another UX question is how TorchTitan should expose HSDP. There are two options. One is to expose dp_shard and dp_replicate to users: for DDP, dp_shard == 1 and dp_replicate > 1; for HSDP, dp_shard > 1 and dp_replicate > 1; for FSDP, dp_shard > 1 and dp_replicate == 1.
An alternative, which this PR uses, is to expose dp_type: users explicitly specify FSDP, HSDP, or DDP, so we need another way to express the two degrees. This PR currently exposes dp_replicate, but another suggestion is to let dp accept both int and Tuple[int, int].
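As an illustration of the first option, here is a small hypothetical helper (infer_dp_type and the returned strings are not from this PR) that derives the DP flavor from the two degrees instead of a separate dp_type knob:

```python
# Hypothetical helper illustrating option 1: infer the DP flavor from
# (dp_replicate, dp_shard) rather than exposing an explicit dp_type.
def infer_dp_type(dp_replicate: int, dp_shard: int) -> str:
    if dp_replicate > 1 and dp_shard == 1:
        return "ddp"
    if dp_replicate == 1 and dp_shard > 1:
        return "fsdp"
    if dp_replicate > 1 and dp_shard > 1:
        return "hsdp"
    return "none"  # dp_replicate == dp_shard == 1: no data parallelism


assert infer_dp_type(dp_replicate=2, dp_shard=4) == "hsdp"
```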
3. Buffer synchronization
DTensor will implicitly synchronize the RNG state. However, there are buffers that are not DTensors. How do we ensure that these buffers are synchronized? This PR currently uses _sync_module_states_with_mesh to synchronize the module states, including parameters and buffers. Another proposal is that users should set the random seed correctly and ensure the buffers are the same on their own.
Conclusion: let users handle the RNG state.
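For reference, a minimal sketch of the broadcast-based idea behind _sync_module_states_with_mesh (the actual implementation may differ). It assumes a world mesh with a "dp_replicate" dimension as above and that the buffers are plain tensors:

```python
import torch
import torch.distributed as dist


def sync_buffers_across_replicas(model: torch.nn.Module, world_mesh) -> None:
    """Broadcast non-DTensor buffers from the first replica to the others."""
    pg = world_mesh.get_group("dp_replicate")
    src = dist.get_global_rank(pg, 0)  # global rank of group-rank 0
    for buf in model.buffers():
        dist.broadcast(buf, src=src, group=pg)
```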
TODO:
- Verify the accuracy.
- Verify the performance.
- Verify the accuracy with TP.
- Detect whether the user's PyTorch version supports HSDP.
@tianyu-l I changed the data_parallel_replicate_degree default to 1. But I don't see why the ParallelDims logic would be simplified -- are we not allowing data_parallel_replicate_degree to be -1? That would be a different story. It's true that the toml files and test_runner.py could be simplified, but now that we have two data parallelism degrees, I would prefer to keep them explicit in these files. So I just left them in the files.