
[RFC] Enable DTensor SPMD APIs with XLA SPMD Backend

Open fhaolinaws opened this issue 6 months ago • 11 comments

🚀 Feature

The previous work in #92909 on DTensor support for the XLA backend enabled factory functions, including distribute_tensor, to create XLA sharded tensors when using XLA SPMD. However, it only allows creating an XLAShardedTensor object; that object does not implement the full set of DTensor APIs users need to work with these tensors in parallelism workflows. This RFC proposes implementing the DTensor APIs in XLAShardedTensor, enabling PyTorch users to express tensor distributions consistently across backends.

Currently, DTensor for XLA devices can be created through:

import torch
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Shard, distribute_tensor

world_size = 4  # placeholder: number of XLA devices participating in SPMD
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

This example is from RFC #92909. With the changes proposed here in place, DTensor APIs and public properties can be used:

from torch.distributed.tensor import Replicate

desired_layouts = [Replicate()]
# Reshard the tensor to be replicated across the mesh
replicated_dtensor = my_dtensor.redistribute(mesh, desired_layouts)
# Create a zero tensor that has the same shape and sharding as my_dtensor
zero_tensor = torch.distributed.tensor.zeros(my_dtensor.size(), device_mesh=mesh, placements=my_dtensor.placements)

Motivation

DTensor is becoming the new fundamental building block for distributed computation in PyTorch. Enabling DTensor with XLA devices offers several key benefits:

  • DTensor provides simple and generic APIs to express tensor distribution in SPMD style.
  • The integration reduces the overhead for users to learn new stacks and migrate their models.
  • The integration allows access to many existing parallelism utilities, including tensor parallel and sequence parallel wrappers.
  • The integration enables popular training frameworks built on top of DTensor, such as TorchTitan.
  • The integration provides the opportunity to consolidate XLA into the DTensor ecosystem, improving alignment between XLA and native Torch.

Currently, Torch XLA provides the XLAShardedTensor, Mesh, and mark_sharding APIs for distributed training. Ongoing work in #92909 has already enabled creating a sharded tensor on an XLA device using DTensor's standard factory functions when using XLA SPMD. However, XLAShardedTensor has not yet implemented the full suite of DTensor interfaces, preventing users from using DTensor with XLA devices the same way they use it with non-XLA devices.

Pitch

We propose implementing the DTensor SPMD APIs and related public properties in XLAShardedTensor, so that XLAShardedTensor can function like a DTensor. MPMD API enablement is out of scope, because DTensor SPMD APIs with XLA backend compilation already provide a complete solution for distributed computation, both functionally and non-functionally. While there may still be some benefit in enabling DTensor + XLA MPMD, this proposal won't focus on it.

The DTensor SPMD APIs and related public properties are listed below (reference: https://docs.pytorch.org/docs/stable/distributed.tensor.html); a short usage sketch follows the list.

### redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from the current DeviceMesh to a new DeviceMesh.
redistribute(device_mesh=None, placements=None, *, async_op=False)

### Return the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate them together.
full_tensor(*, grad_placements=None)

### The DeviceMesh attribute associated with this DTensor object.
device_mesh: DeviceMesh

### The placements attribute of this DTensor that describes its layout on its DeviceMesh.
placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]
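As a concrete illustration, here is a minimal usage sketch reusing my_dtensor, big_tensor, and mesh from the snippets above, assuming the APIs above are implemented:

# Inspect the mesh and layout associated with the sharded tensor.
print(my_dtensor.device_mesh)   # the "xla" DeviceMesh created earlier
print(my_dtensor.placements)    # (Shard(dim=0),)

# Gather all shards back into a regular, unsharded torch.Tensor.
global_tensor = my_dtensor.full_tensor()
assert global_tensor.shape == big_tensor.shape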

While #92909 implemented the DTensor factory function distribute_tensor to support creating a sharded tensor from a regular tensor, there are other convenience factory functions that haven't been implemented. We list them here and propose to implement them for full compatibility (reference: https://docs.pytorch.org/docs/stable/distributed.tensor.html):

torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)

While there is existing support to convert a DTensor DeviceMesh to an XLA Mesh and use it to create a sharded tensor on an XLA device, a DeviceMesh can also be used in the form of a submesh (a mesh with a subset of the dimensions of the root mesh, convenient when composing multiple sharding strategies), which XLA does not currently support. This would require implementing the following APIs (reference: https://github.com/pytorch/pytorch/blob/60abb0d3273749cb2a7d583c7c2863bd2819e87e/torch/distributed/device_mesh.py):

_MeshEnv.create_sub_mesh(self, device_mesh: "DeviceMesh", submesh_dim_names: tuple[str, ...], submesh_dims: list[tuple[int, ...]],)
_MeshEnv.get_root_mesh(self, device_mesh: "DeviceMesh")
DeviceMesh.__getitem__(self, mesh_dim_names: Union[str, tuple[str, ...]])

High-level Approach

Changes to XLAShardedTensor

The first change is to make XLAShardedTensor inherit from DTensor. This ensures that XLAShardedTensor can be used wherever a DTensor is required. We implement the SPMD APIs specified above and explicitly raise exceptions for MPMD API calls, with a clear message telling users that MPMD APIs are disabled when using XLA SPMD. A minimal sketch of this change follows.
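A minimal sketch of the class change; the method bodies are placeholders and the MPMD-only surface shown is illustrative:

from torch.distributed.tensor import DTensor


class XLAShardedTensor(DTensor):
    """XLA sharded tensor that behaves like a DTensor under XLA SPMD."""

    def redistribute(self, device_mesh=None, placements=None, *, async_op=False):
        # SPMD API: re-annotate sharding via clone + mark_sharding (see below).
        ...

    def full_tensor(self, *, grad_placements=None):
        # SPMD API: replicate across the mesh and return the global tensor (see below).
        ...

    def to_local(self, *, grad_placements=None):
        # MPMD-style API: there is no per-rank local view under XLA SPMD,
        # so we fail loudly with an actionable message.
        raise NotImplementedError(
            "to_local() is not supported when using XLA SPMD; "
            "the program operates on global tensors only."
        )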

DTensor.redistribute

The redistribute API reshards a DTensor from its current placements to new placements by calling collectives. XLA doesn't expose explicit control over resharding a tensor, but the same effect can be achieved by cloning the existing tensor and calling mark_sharding on the new tensor. Alternatively, we could add a new XLA reshard API that hides the clone and possibly handles buffer donation, making the process more efficient.

def redistribute(self, device_mesh=None, placements=None, *, async_op=False):
    ...
    # Clone the current tensor and re-shard the clone with the new placements.
    clone_tensor = self.clone().detach()
    return distribute_tensor(clone_tensor, device_mesh, placements)
DTensor.full_tensor

The full_tensor API returns the global tensor. With XLA, we can simply implement it on top of the redistribute API with replicate placements.
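A possible sketch, building on the redistribute implementation above and assuming XLAShardedTensor's existing global_tensor property:

from torch.distributed.tensor import Replicate


def full_tensor(self, *, grad_placements=None):
    # Replicate across every mesh dimension, then return the unsharded
    # global tensor.
    replicate_placements = [Replicate()] * self.device_mesh.ndim
    replicated = self.redistribute(self.device_mesh, replicate_placements)
    return replicated.global_tensor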

DTensor Properties

XLAShardedTensor does not currently store device_mesh and placements, so we will need to save them. Once XLAShardedTensor inherits from DTensor, these fields will already be available from the base class.
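If these fields turn out not to be usable directly from the DTensor base class, here is a hedged sketch of storing and exposing them explicitly (attribute names are illustrative):

from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Placement


class XLAShardedTensor(DTensor):
    ...

    @property
    def device_mesh(self) -> DeviceMesh:
        # Mesh captured when the tensor was created via distribute_tensor
        # or last resharded via redistribute.
        return self._device_mesh

    @property
    def placements(self) -> tuple[Placement, ...]:
        # Placements captured at creation / last redistribute.
        return self._placements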

DTensor Factory Functions

All the factory functions listed above (zeros, ones, etc.) rely on an internal DTensor helper, _dtensor_init_helper (reference). This helper doesn't use the distribute_tensor factory function because it takes advantage of the properties of constant tensors: it initializes only the local tensor and marks it as distributed, instead of distributing a global tensor. However, when using XLA SPMD the buffers are allocated with the sharded shape anyway, so we can simply replace the implementation with distribute_tensor:

def _dtensor_init_helper(
    init_op,
    size: torch.Size,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    **kwargs,
) -> DTensor:
    ...
    if device_mesh.device_type == 'xla':
        # Initialize the global tensor, then let distribute_tensor shard it.
        if init_op == torch.full:
            fill_value = kwargs.pop("fill_value", 0)
            global_tensor = init_op(size, fill_value, **kwargs)
        elif init_op == torch.rand or init_op == torch.randn:
            # This tensor meta is not used except for `shape`.
            dtype = kwargs.get("dtype", torch.get_default_dtype())

            tensor_meta = TensorMeta(size, (0,), dtype)
            spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)

            if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
                random._rng_tracker = random.OffsetBasedRNGTracker(device_mesh)

            assert random._rng_tracker is not None
            with random._rng_tracker._distribute_region(spec):
                global_tensor = init_op(size, **kwargs)
        else:
            global_tensor = init_op(size, **kwargs)

        return distribute_tensor(global_tensor, device_mesh, placements)
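With that replacement in place, the factory functions would work uniformly on an XLA mesh, for example (a sketch reusing mesh from the first snippet):

import torch.distributed.tensor as dt

# Tensors created directly with a sharded layout on the XLA mesh.
zero_dt = dt.zeros(100000, 88, device_mesh=mesh, placements=[dt.Shard(0)])
rand_dt = dt.randn(100000, 88, device_mesh=mesh, placements=[dt.Shard(0)])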
DeviceMesh Submesh Utils

A submesh is a slice of the root DeviceMesh selected by dimension name. For example, a 2-D mesh mesh_2d['dp', 'tp'] with 8 ranks can yield a submesh mesh_2d['dp'] with 4 ranks, e.g. [0,1,2,3], when the current rank is one of [0,1,2,3]. Torch XLA doesn't have built-in support for submeshes. To make this work, we can enable submeshes in Torch XLA; this could simply be a thin helper class on top of Mesh that handles indexing by dim name and mapping back to root-mesh dims. Alternatively, we can avoid introducing the submesh concept in Torch XLA altogether; instead, a DTensor submesh would be handled and mapped back to the root mesh when converting to an XLA mesh. A sketch of the desired user-facing behavior follows.
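A hedged sketch of the desired behavior once submesh support exists (the 4x2 layout and dim names here are illustrative):

import torch
from torch.distributed import DeviceMesh

# Hypothetical 2-D mesh over 8 XLA devices: 4 "dp" groups x 2 "tp" groups.
mesh_2d = DeviceMesh(
    "xla",
    torch.arange(8).reshape(4, 2),
    mesh_dim_names=("dp", "tp"),
)

# Slicing by dim name yields a 1-D submesh; under this proposal it would be
# mapped back to the root XLA mesh (or to a thin submesh helper) when the
# tensor is actually sharded.
dp_mesh = mesh_2d["dp"]   # ranks that share this rank's "tp" coordinate
tp_mesh = mesh_2d["tp"]   # ranks that share this rank's "dp" coordinate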

Alternatives

The goal of this RFC is to take advantage of the generic and simple DTensor API and provide a consistent user experience across backends. This could potentially be achieved by integrating the XLA backend with other APIs, too. Considering the native Torch ecosystem, the wide user base, and the existing integration with XLA, we think it is beneficial to continue this integration and consider other candidates later.

CC @bfolie @qihqi @GleasonK @aws-rhsoln @rpsilva-aws

fhaolinaws avatar Jun 27 '25 21:06 fhaolinaws

XLA doesn't expose explicit control over resharding a tensor, but the same effect can be achieved by cloning the existing tensor and calling mark_sharding on the new tensor

Is it not possible to mark_sharding on top of an already-sharded tensor? I recall a discussion to this effect: it works in some cases but not others, and there's a way to make it work in all cases. That seems like it would be more efficient in many cases than a full gather and reshard. Does this ring a bell to you?

bfolie avatar Jun 30 '25 14:06 bfolie

@bfolie We have https://github.com/pytorch/xla/pull/9203, which may help with most variants. If used correctly, this should be preferred over a clone, since that does not do in-place tensor sharding (e.g. mark_sharding_with_gradients).

rpsilva-aws avatar Jun 30 '25 17:06 rpsilva-aws

Thanks @rpsilva-aws, yes that is the discussion I was thinking of

bfolie avatar Jun 30 '25 17:06 bfolie

Yes, the API in #9203 should be usable for this case.

fhaolinaws avatar Jun 30 '25 19:06 fhaolinaws

I think the redistribute API implementation might not be as simple as just annotating the tensor with the new sharding. That may be one of the things we have to do, but we might also need to adjust a few more APIs:

  1. _ToTorchTensor and _FromTorchTensor
  2. Handle sharding propagation in PyTorch. This just needs proper handling of Placements and sharding. During the sharding propagation pass in PyTorch, placements and sharding are propagated to some of the tensors in the model flow (until it hits a replicated tensor); it might be valuable to convert all of those into mark_sharding calls, which can give extra hints to the partitioner.

aws-rhsoln avatar Jun 30 '25 20:06 aws-rhsoln

@aws-rhsoln thanks for your input. For 1, a proposed solution is to simply do nothing for DTensor's to_local and from_local functions, which are the functions that use the classes you mentioned. The reason is that these functions convert an already-partitioned local tensor to a DTensor, while in XLA SPMD there is no view of local tensors. There may be a concern about backward compatibility when users have a training script using local tensors and want to migrate to DTensor with XLA SPMD. But when they switch to XLA SPMD they need to make adaptive changes anyway, so local tensors no longer exist. One example is the dataloader change: when using native Torch without XLA SPMD, the batch_size set for the dataloader is the per-rank mini-batch size, and the input data on each rank is local and sharded. When switching to XLA SPMD, users need to adjust batch_size to be the global batch size. So the local tensor is no longer an issue to be handled; a rough sketch of this adjustment follows.
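For illustration, a rough sketch of the batch-size adjustment described above (train_dataset, world_size, and the per-rank batch size of 32 are made-up placeholders):

from torch.utils.data import DataLoader

# Native Torch without XLA SPMD: every rank loads its own mini batch.
per_rank_loader = DataLoader(train_dataset, batch_size=32)

# XLA SPMD: the script sees the global batch; the input tensor is then
# sharded across devices (e.g. via distribute_tensor / mark_sharding).
global_loader = DataLoader(train_dataset, batch_size=32 * world_size)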

For 2, I think this could be an add-on and I'll create another sub-issue for it. An additional sharding propagation in the frontend should be beneficial in some scenarios, including easier tracing of the propagation process. But it comes with the cost that developers need to be aware of two sharding propagation frameworks, each with its own rules and cost models, which we need to be careful about. As for the gaps to enable this propagator: the main thing to take care of is that it also partitions the tensors and inserts collectives, which is not what we want for GSPMD, so we would need to replace a considerable amount of the implementation. The work includes having a correct DTensorSpec for the XLAShardedTensor and providing an SPMD solution for the DTensor shard utils. Currently, if we want to ignore this propagator and only use the GSPMD one, we can change the DTensor op dispatcher to exclude the XLA device.

fhaolinaws avatar Jul 02 '25 23:07 fhaolinaws

Some naive questions:

  1. The doc describes going from PyTorch native to XLA equivalence. Should there be a path going the other way?
  2. Can the DTensor propagator be the main one, with the XLA SPMD one disabled?
  3. Would going to an FX graph via dynamo/torch.compile and then converting to HLO be a better route?
  4. Could you describe which changes would go to which package, torch vs torch-xla?

jeffhataws avatar Jul 03 '25 16:07 jeffhataws

@jeffhataws for your questions:

  1. I'm not sure I understand the question. If you are asking whether we can support the XLA device in native Torch, e.g. by putting XLA handling logic inside DTensor, I would expect that to lead to a lot of if-else handling for XLA, which could pollute the code.
  2. If we want to use the DTensor propagator with XLA SPMD, all it can do behind the scenes is add extra mark_sharding annotations; the annotations still need to be handled by the GSPMD propagator and partitioner. If using MPMD, we can rely on the DTensor propagator, but that is not the focus of this RFC.
  3. This could be a replacement for implementing the DTensor interface in XLAShardedTensor, but it comes with more dependencies and risks. Dynamo needs to be evaluated since it has some uncertainty, e.g. graph breaks that could hurt training performance. The work would also depend on the progress of supporting FX graph lowering to HLO. This RFC continues the solution in #92909, which is relatively simple and well understood.
  4. The API implementation changes would all go to torch-xla. There could be some DTensor dispatcher and initialization changes going to torch.

fhaolinaws avatar Jul 03 '25 23:07 fhaolinaws

To fit this proposal into the big picture: using the Torch XLA SPMD backend is one solution for enabling the DTensor API with XLA, in which the GSPMD sharding propagator and partitioner solve the sharding problems. Another solution could be using the Torch XLA MPMD backend, in which sharding is handled in the frontend with DTensor's native implementation. A good thing about this separation is that Torch XLA MPMD adapts more naturally to native Torch multi-process implementations and so allows seamless integration (both non-dynamo and dynamo), including frontend handling of collectives; for Torch XLA SPMD, the existing DTensor implementations do a considerable amount of duplicate work, so it's better to keep only the APIs. So we can keep the door open to enabling DTensor with Torch XLA MPMD. Then it would be the users' choice whether they want a native Torch solution or an XLA solution for sharding.

fhaolinaws avatar Jul 08 '25 20:07 fhaolinaws

Thanks for this proposal and context @fhaolinaws. As we've discussed I think this is a good approach. It improves the user experience, provides both collective-based and XLA-based options, and should be low-risk and straightforward to implement.

bfolie avatar Jul 08 '25 20:07 bfolie

To @jeffhataws's point on using PT's propagator: I think that might be beneficial and would match what native PyTorch does. This is what I was trying to hint at about using the sharding propagator from PyTorch. Might be worth investigating. I did a small PoC, and it seems that using the sharding propagator from PyTorch can give us per-tensor sharding info which we can later use to mark_shard tensors.

aws-rhsoln avatar Jul 21 '25 17:07 aws-rhsoln