Use new DeviceMesh unflatten to rewrite parallel_dims

Open fegin opened this issue 4 months ago • 0 comments

Summary: This PR uses the latest DeviceMesh APIs to simplify the creation of all the different meshes.

The design philosophy is as follows:

  1. Create one world mesh with shape [world_size,].
  2. Create all 1-D submeshes either by 1) unflattening the world mesh, or 2) slicing and flattening other derived meshes.
  3. ParallelDims now provides an API, get_mesh(), which accepts a str or a list[str]. Given a str, the API directly returns the corresponding 1-D submesh. Given a list[str], the named dims are concatenated to form an n-D device mesh.

fegin, Aug 29 '25 05:08