🚀[FEA]: Distributed Training/Inference: handle scatter/gather better and more consistently
Is this a new feature, an improvement, or a change to existing functionality?
Improvement
How would you describe the priority of this feature request
Low (would be nice)
Please provide a clear description of problem you would like to solve.
The problem arises in model-parallel settings where not all ranks hold valid tensors; it mainly concerns the gather and scatter routines.
Scatter
- scatter assumes a single tensor on a source rank which is distributed in parts across the other ranks
- to be able to receive these chunks, however, these ranks need to know the dtype and other meta-information like `requires_grad` in order not to break training pipelines
- current solutions either require the user to specify these things on each rank, or assume empty "dummy" tensors on each rank that carry this information; these, however, might not be robust when registered in the compute graphs of autograd frameworks (see the sketch below)
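A minimal sketch of the kind of workaround this implies, assuming a hypothetical helper `scatter_with_meta` (not an existing API) that broadcasts the required meta-data before the actual scatter:

```python
import torch
import torch.distributed as dist

def scatter_with_meta(chunks, src=0, group=None):
    """Hypothetical helper: scatter `chunks` (a list of per-rank tensors,
    valid only on `src`) while broadcasting the meta-data the receiving
    ranks need in order to allocate their buffers."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # The source rank packs per-chunk meta-data; all other ranks receive it.
    if rank == src:
        meta = [(c.shape, c.dtype, c.requires_grad) for c in chunks]
    else:
        meta = [None] * world_size
    dist.broadcast_object_list(meta, src=src, group=group)

    # Every rank can now allocate a correctly typed receive buffer ...
    shape, dtype, requires_grad = meta[rank]
    recv = torch.empty(shape, dtype=dtype)

    # ... and participate in the scatter without user-supplied meta-data
    # or autograd-visible "dummy" tensors.
    dist.scatter(recv, scatter_list=list(chunks) if rank == src else None,
                 src=src, group=group)
    return recv.requires_grad_(requires_grad)
```

This avoids dummy tensors in the autograd graph, but at the cost of an extra object broadcast per call.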
Gather
- the backward pass of gather is a scatter call, so similar problems arise, although this case can be handled more easily by e.g. storing meta-data in the corresponding context of the `torch.autograd.Function` (see the sketch after this list)
- the main issue rather arises in upstream layers if gather returns `None` on all participating ranks; it would be more informative to have an object carrying the information that this `None` is just the null part of a distributed tensor which is currently valid on rank X
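For the first point, a rough sketch (not an existing implementation) of a gather as a `torch.autograd.Function` that stashes the meta-data for the backward scatter in `ctx`. Note that on ranks where forward returns `None` there is no tensor for autograd to differentiate through, which is part of why a bare `None` is awkward for upstream layers:

```python
import torch
import torch.distributed as dist

class Gather(torch.autograd.Function):
    """Sketch: gather equally sized chunks onto rank `dst`; the backward
    pass is the corresponding scatter of the gradient."""

    @staticmethod
    def forward(ctx, local_chunk, dst, group):
        ctx.dst = dst
        ctx.group = group
        # Meta-data the backward scatter needs to allocate its buffer.
        ctx.local_shape = local_chunk.shape
        ctx.local_dtype = local_chunk.dtype

        if dist.get_rank(group) == dst:
            parts = [torch.empty_like(local_chunk)
                     for _ in range(dist.get_world_size(group))]
        else:
            parts = None
        dist.gather(local_chunk, gather_list=parts, dst=dst, group=group)

        if parts is not None:
            return torch.cat(parts, dim=0)
        # Every other rank only sees `None` -- nothing signals that this is
        # the null part of a tensor that is currently valid on rank `dst`.
        return None

    @staticmethod
    def backward(ctx, grad_output):
        grad_local = torch.empty(ctx.local_shape, dtype=ctx.local_dtype)
        if dist.get_rank(ctx.group) == ctx.dst:
            chunks = [c.contiguous() for c in
                      grad_output.chunk(dist.get_world_size(ctx.group), dim=0)]
        else:
            chunks = None
        dist.scatter(grad_local, scatter_list=chunks, src=ctx.dst, group=ctx.group)
        return grad_local, None, None
```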
Potential Solution
- in general, we should make things more consistent throughout
- a potential solution would be to define something like a `TensorPlaceholder` which carries meta-data on ranks where the tensor is currently not valid and is more informative than just `None` (an illustrative sketch follows below)
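Purely as an illustration (this class does not exist anywhere yet), such a placeholder could look roughly like:

```python
from dataclasses import dataclass
import torch

@dataclass
class TensorPlaceholder:
    """Stands in for the locally missing part of a distributed tensor.

    Unlike a bare `None`, it records which rank currently holds the data
    and the meta-data needed to receive it later, without registering a
    dummy tensor in the autograd graph."""
    shape: torch.Size
    dtype: torch.dtype
    device: torch.device
    requires_grad: bool
    src_rank: int

    def empty_buffer(self) -> torch.Tensor:
        # Allocate a correctly typed receive buffer, e.g. ahead of a
        # scatter or broadcast from `src_rank`.
        t = torch.empty(self.shape, dtype=self.dtype, device=self.device)
        return t.requires_grad_(self.requires_grad)
```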
Describe any alternatives you have considered
No response
@stadlmax could you please provide an update on this issue?
I would say that the problem still exists in general. In our implementations, we should have enough workarounds in place to avoid these issues. If we want to simplify the lives of people implementing other distributed solutions, one could think of tackling that issue. I would say low priority.
Orthogonal to that, we can keep monitoring the progress on DTensor in upstream PyTorch, which could be a solution to this minor problem.
On DTensor: It is unlikely that DTensor will ever be suitable for this task. The challenge is that DTensor explicitly assumes tensors are distributed across ranks as if you called torch.chunk and each chunk went to a subsequent device.
I have implemented ShardTensor in modulus as a relatively lightweight extension to DTensor to get around this, though it may not cover the exact cases seen here. In particular, scatter/gather are wrapped in syntax like:
Scatter a tensor:
# a is a tensor on rank 0
# target_mesh is a torch DeviceMesh indicating a specific group of GPUs
# target_placements is a list of placements from DTensor (len(target_placements) == target_mesh.ndim)
# that specifies exactly how, and along which axis (or axes), a tensor is to be distributed. Supports sharding and replication.
scattered_a = distribute_tensor(a, target_mesh, target_placements)
# ^ this is differentiable
(all)Gather a tensor:
gathered_a = scattered_a.full_tensor()
# ^ this is differentiable
One point is that this full_tensor operation implies an allgather, not a single-device-gather, because it relies on DTensor concepts which are largely focused on single-program, multi-gpu parallelism.
The above syntax works for DTensor and ShardTensor equally, though the DTensor gather/scatter will assume torch.chunk layouts as mentioned. ShardTensor will not make these assumptions but instead will use the appropriate (and differentiable) scatter_v/gather_v operations to coalesce a distributed tensor.
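For reference, a minimal end-to-end sketch of the DTensor flavor of this syntax (assuming a recent PyTorch where `torch.distributed.tensor` is public; older releases expose the same functions under `torch.distributed._tensor`):

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Assumes a launch via torchrun so WORLD_SIZE etc. are set;
# one mesh dimension spanning all ranks.
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

a = torch.randn(1024, 32, requires_grad=True)

# Scatter: shard `a` along dim 0 across the mesh (differentiable).
scattered_a = distribute_tensor(a, mesh, placements=[Shard(0)])

# (All)gather: reassemble the full tensor on every rank (differentiable);
# as noted above, full_tensor() is an allgather, not a single-rank gather.
gathered_a = scattered_a.full_tensor()
```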
ShardTensor can also (correctly) work with this syntax, on every rank:
distributed_tensor = ShardTensor.from_local(local_slice, target_mesh, target_placements)
This will work regardless of the local_slice size, as long as the local slices across ranks have the same rank (number of dimensions) and compatible shapes, such that they can be gathered into one global tensor. That is, tensor dimensions must agree on all ranks, except for the dimensions being sharded.
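A rough sketch of that pattern with uneven local slices; the ShardTensor import path is an assumption here (check the physicsnemo documentation for the authoritative API):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

# Assumed import path for ShardTensor -- verify against the docs.
from physicsnemo.distributed import ShardTensor

# One mesh dimension spanning all ranks (launched e.g. via torchrun).
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Local slices may differ in size along the sharded dim (dim 0 here),
# as long as the remaining dimensions (here 32) agree on every rank.
local_slice = torch.randn(100 + 7 * dist.get_rank(), 32, device="cuda")

distributed_tensor = ShardTensor.from_local(local_slice, mesh, [Shard(0)])
full = distributed_tensor.full_tensor()  # differentiable (all)gather_v
```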
We should take a look at the use cases for scatter/gather in modulus to see if it makes sense to apply this in existing code, or if it will be useful in new developments.
I believe this functionality is now handled with ShardTensor: https://docs.nvidia.com/physicsnemo/latest/user-guide/domain_parallelism_entry_point.html.
Please open a fresh issue if more functionality is needed.