torchtitan
Use new DeviceMesh unflatten to rewrite parallel_dims
Summary
This PR uses the latest DeviceMesh APIs to simplify the creation of all the different meshes.
The design philosophy is as follows:
- Create one world mesh with shape [world_size,].
- Create all 1-D submeshes by either 1) unflattening from the world mesh, or 2) slicing and flattening from other derived meshes.
- ParallelDims now provides an API, get_mesh(), which accepts a str or a list[str]. When the argument is a str, the API directly returns the corresponding 1-D submesh. When the argument is a list[str], the named dims are concatenated to form an n-D device mesh.
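The list above can be illustrated with a toy model in plain Python. This is only a sketch of the idea, not torchtitan's implementation: it uses no torch APIs, ranks are plain integers, the class name ToyParallelDims and the explicit rank argument are hypothetical (a real get_mesh() would use the calling process's rank), and a row-major layout with the first dim outermost is assumed.

```python
from math import prod


class ToyParallelDims:
    """Hypothetical stand-in for ParallelDims; ranks are plain integers."""

    def __init__(self, world_size: int, dims: dict[str, int]) -> None:
        if prod(dims.values()) != world_size:
            raise ValueError("dim sizes must multiply to world_size")
        # The single 1-D world mesh with shape [world_size,].
        self.world = list(range(world_size))
        # Insertion order matters: the first dim is the outermost one.
        self.dims = dict(dims)

    def _coords(self, rank: int) -> dict[str, int]:
        """Unflatten a rank into one coordinate per dim (row-major)."""
        coords = {}
        for name in reversed(list(self.dims)):
            rank, coord = divmod(rank, self.dims[name])
            coords[name] = coord
        return coords

    def _rank(self, coords: dict[str, int]) -> int:
        """Inverse of _coords: flatten coordinates back to a rank."""
        rank = 0
        for name, size in self.dims.items():
            rank = rank * size + coords[name]
        return rank

    def _build(self, names: list[str], coords: dict[str, int]):
        """Vary the requested dims; keep all other coordinates fixed."""
        if not names:
            return self._rank(coords)
        head, rest = names[0], names[1:]
        return [
            self._build(rest, {**coords, head: i})
            for i in range(self.dims[head])
        ]

    def get_mesh(self, names, rank: int):
        """str -> 1-D submesh; list[str] -> n-D mesh over those dims."""
        if isinstance(names, str):
            names = [names]
        return self._build(list(names), self._coords(rank))


if __name__ == "__main__":
    pd = ToyParallelDims(8, {"dp": 2, "tp": 4})
    print(pd.get_mesh("tp", rank=5))          # [4, 5, 6, 7]
    print(pd.get_mesh("dp", rank=5))          # [1, 5]
    print(pd.get_mesh(["dp", "tp"], rank=5))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

In this toy, a str argument yields the ranks that differ from the given rank only along that one dim, while a list[str] argument yields a nested list with one axis per requested name, in the order given.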