physicsnemo

🚀[FEA]: Add mechanism for a dataloader to be aware of model parallel partitioning

Open · akshaysubr opened this issue 2 years ago · 1 comment

Is this a new feature, an improvement, or a change to existing functionality?

New Feature

How would you describe the priority of this feature request

Medium

Please provide a clear description of problem you would like to solve.

With model-parallel training, each process may need only a subset of the full input. This can reduce the per-process I/O requirements, but the model-parallel partitioning is typically decided by the model implementation, and the dataloader is unaware of it. It would be good to add a mechanism that makes the dataloader aware of this partitioning, either by querying DistributedManager directly or by letting the user query the model instance for this information and pass it down to the dataloader.
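The second option, where the model reports its partitioning and the user hands it to the dataloader, could look roughly like the sketch below. Every name here (`PartitionSpec`, `SpatiallyParallelModel`, `partition_spec`, `local_index`) is hypothetical, and the sketch assumes the model splits one input axis evenly into contiguous blocks:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PartitionSpec:
    """Hypothetical description of how a model splits its input across ranks."""

    axis: int  # input dimension that is partitioned
    rank: int  # this process's rank in the model-parallel group
    size: int  # number of model-parallel ranks


class SpatiallyParallelModel:
    """Hypothetical model that chooses its own partitioning and can report it."""

    def __init__(self, mp_rank: int, mp_size: int, split_axis: int = 0):
        self._spec = PartitionSpec(axis=split_axis, rank=mp_rank, size=mp_size)

    def partition_spec(self) -> PartitionSpec:
        return self._spec


def local_index(shape: tuple[int, ...], spec: PartitionSpec) -> tuple[slice, ...]:
    """Slices selecting this rank's contiguous block; the axis must divide evenly."""
    length = shape[spec.axis]
    if length % spec.size:
        raise ValueError(f"axis {spec.axis} (length {length}) not divisible by {spec.size}")
    chunk = length // spec.size
    index = [slice(None)] * len(shape)
    index[spec.axis] = slice(spec.rank * chunk, (spec.rank + 1) * chunk)
    return tuple(index)


if __name__ == "__main__":
    # The user queries the model and passes the spec down to the data pipeline.
    model = SpatiallyParallelModel(mp_rank=2, mp_size=4, split_axis=1)
    print(local_index((3, 16, 16), model.partition_spec()))
```

The returned tuple of slices can index any array-like object that supports basic slicing. With a lazily loaded store such as an h5py dataset, only the selected block is read from disk.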

Describe any alternatives you have considered

None

akshaysubr · Sep 07 '23 17:09

The upcoming upgrades to DistributedManager will enable DeviceMesh in Modulus. While mesh configuration and dimension naming will be up to the user, the general workflow will look like this:

# dm: an initialized DistributedManager (init_device_mesh is part of the upcoming API)
mesh_sizes = [-1, 4]
mesh_names = ["data_parallel", "model_parallel"]
mesh = dm.init_device_mesh(mesh_sizes, mesh_names)

This configures PyTorch's process groups by creating a DeviceMesh that organizes GPUs into a hierarchy of groups. Assuming the user names mesh dimensions consistently, if the model uses model_parallel to designate that it is distributed across 4 GPUs, each dataloader process could do the following:

import torch.distributed as dist

model_parallel_mesh = dm.mesh()['model_parallel']
model_parallel_group = model_parallel_mesh.get_group()
# get_group_rank takes the process group and the global rank, and returns the rank within that group
this_model_parallel_rank = dist.get_group_rank(model_parallel_group, dm.rank)
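To see which values `this_model_parallel_rank` takes, here is a pure-Python sketch (no PyTorch needed; all helper names are hypothetical). It assumes the row-major rank layout that PyTorch's `init_device_mesh` uses, and it assumes that a `-1` entry in `mesh_sizes` absorbs the remaining ranks. The second point is an assumption about the upcoming DistributedManager API. The same bookkeeping works for meshes with more than two named dimensions.

```python
from math import prod


def resolve_mesh_shape(mesh_sizes: list[int], world_size: int) -> list[int]:
    """Fill in a single -1 entry so the mesh covers exactly world_size ranks."""
    if mesh_sizes.count(-1) > 1:
        raise ValueError("at most one mesh dimension may be -1")
    known = prod(s for s in mesh_sizes if s != -1)
    shape = [world_size // known if s == -1 else s for s in mesh_sizes]
    if prod(shape) != world_size:
        raise ValueError(f"mesh {mesh_sizes} cannot cover {world_size} ranks")
    return shape


def mesh_coords(rank: int, shape: list[int], names: list[str]) -> dict[str, int]:
    """Coordinates of a global rank in a row-major mesh, keyed by dimension name.

    The last dimension varies fastest, as in torch.arange(world).view(shape).
    """
    coords: dict[str, int] = {}
    for name, size in zip(reversed(names), reversed(shape)):
        rank, coords[name] = divmod(rank, size)
    return {name: coords[name] for name in names}


def group_members(rank: int, shape: list[int], names: list[str], dim: str) -> list[int]:
    """Global ranks sharing every coordinate with `rank` except along `dim`."""
    mine = mesh_coords(rank, shape, names)
    return [
        r
        for r in range(prod(shape))
        if all(c == mine[n] for n, c in mesh_coords(r, shape, names).items() if n != dim)
    ]


if __name__ == "__main__":
    names = ["data_parallel", "model_parallel"]
    shape = resolve_mesh_shape([-1, 4], world_size=8)
    for r in range(8):
        members = group_members(r, shape, names, "model_parallel")
        # The position in the ordered group list plays the role of dist.get_group_rank.
        print(r, mesh_coords(r, shape, names), members, members.index(r))
```

Under this layout, a rank's model_parallel coordinate and its rank within the model-parallel group coincide. Recent PyTorch releases also provide `DeviceMesh.get_local_rank(mesh_dim)`, which returns that coordinate directly and may be a simpler alternative to the `get_group` / `get_group_rank` pair.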

From there, the dataloader can optimize data loading and partitioning based solely on its rank within the model-parallel group, and it can rely on that rank matching the one the model's code sees. This approach also extends to more levels of parallelism through higher-dimensional DeviceMesh objects.
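Here is one way a dataloader could act on these ranks. The hypothetical `ShardedLoader` takes `mp_rank` (the value `this_model_parallel_rank` provides) and an analogous rank in the data_parallel group. It uses the data-parallel rank to choose samples and the model-parallel rank to choose a contiguous block of rows within each sample. The strided sample assignment and the row-wise split are illustrative assumptions, not something DistributedManager prescribes. A real loader must mirror whatever partitioning the model actually applies.

```python
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class ShardedLoader:
    """Hypothetical loader that reads only this rank's share of the data.

    - the data-parallel rank selects which samples (strided over the dataset);
    - the model-parallel rank selects which rows of each sample (contiguous split,
      where the first rows_per_sample % mp_size ranks get one extra row).
    read_rows(sample_idx, start, stop) stands in for a partial read, so rows
    owned by other model-parallel ranks are never loaded by this process.
    """

    num_samples: int
    rows_per_sample: int
    dp_rank: int
    dp_size: int
    mp_rank: int
    mp_size: int
    read_rows: Callable[[int, int, int], list]

    def row_range(self) -> tuple[int, int]:
        base, extra = divmod(self.rows_per_sample, self.mp_size)
        start = self.mp_rank * base + min(self.mp_rank, extra)
        stop = start + base + (1 if self.mp_rank < extra else 0)
        return start, stop

    def __iter__(self) -> Iterator[list]:
        start, stop = self.row_range()
        for idx in range(self.dp_rank, self.num_samples, self.dp_size):
            yield self.read_rows(idx, start, stop)


if __name__ == "__main__":
    def fake_read(idx: int, start: int, stop: int) -> list:
        # Stand-in for e.g. an HDF5 or Zarr slice: label each row that would be read.
        return [(idx, row) for row in range(start, stop)]

    loader = ShardedLoader(num_samples=6, rows_per_sample=10,
                           dp_rank=1, dp_size=2, mp_rank=3, mp_size=4,
                           read_rows=fake_read)
    for sample in loader:
        print(sample)  # rows 8-9 of samples 1, 3 and 5
```

Because `read_rows` receives the row bounds, a backend that supports partial reads (such as h5py dataset slicing) only touches the rows this rank owns. That is the per-process I/O saving the original request is after.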

@akshaysubr would this meet your requirements?

coreyjadams · Jan 28 '25 20:01