[RFC] XLAShardedTensor & Sharding Annotation API
🚀 Feature
We propose XLAShardedTensor to represent a sharded tensor that wraps around torch.Tensor, and a mark_sharding() API for tensor sharding annotation. XLAShardedTensor allows annotating tensors with sharding specs and dispatching the annotations to the XLA backend for XLA GSPMD support in PyTorch/XLA.
Usage Example
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_sharding as xs
m1 = torch.randn(8, 4).to(xm.xla_device())
mesh_shape = (2, 4) # device mesh
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(m1, mesh_shape, partition_spec)
# XLAShardedTensor behaves like an unpartitioned native tensor:
# - we can apply native torch.Tensor ops and nn.layers
# - print the unpartitioned tensor when sent to CPU
# - provides access to the list of local shards after lazy execution
assert isinstance(m1_sharded, XLAShardedTensor)
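For illustration, the snippet below sketches those three behaviors. It is a hedged example, not part of the proposed API surface: it assumes the local_shards attribute described later in this RFC, and uses the existing xm.mark_step() call to trigger lazy execution.
m2 = torch.randn(4, 16).to(xm.xla_device())
out = m1_sharded @ m2             # native torch ops work on the sharded tensor
print(out.cpu())                  # moving to CPU prints the unpartitioned tensor
xm.mark_step()                    # lazy execution materializes the device data
shards = m1_sharded.local_shards  # per-device shards, available only after execution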
Motivation
Our goal is to support GSPMD model sharding in PyTorch -- this would allow a user to bring in a PyTorch model implemented (as if) on a single device and annotate a few tensors with the desired sharding specs to run efficient model parallelism. Such an automated approach to model sharding allows the XLA compiler to optimize the entire computation graph end-to-end and frees the user from implementing sharded versions of ops with proper collectives in place.
The current PyTorch ShardedTensor abstraction RFC also provides simple primitives to express sharded tensors, and the ShardedTensor API provides a set of convenience helper functions to shard tensors or model parameters with sharding specs. Under the hood, it requires manual/explicit implementation of sharded ops (e.g., a sharded version of torch.nn.functional.linear) and careful injection of collective comms. Moreover, that abstraction represents sharded tensors directly; it does not express the sharding annotations needed for XLA compiler-based sharding, which takes place lazily and targets the xla backend.
Pitch
To enable our XLA compiler-based sharding, we propose the XLAShardedTensor and mark_sharding API. In this section, we also describe how users can specify different tensor sharding strategies in the sharding annotation.
XLAShardedTensor
The main use case for XLAShardedTensor is to annotate a native torch.Tensor (on a single device) with a sharding spec. The annotation takes place immediately, but the actual sharding of the tensor happens lazily. Once a tensor is annotated and wrapped inside an XLAShardedTensor, it can be passed to existing PyTorch ops and nn.Module layers as a torch.Tensor. This is critical to ensure that layers and tensor ops can be stacked together as before, which means that the user does not need to rewrite the existing single-device model for sharded computation. Namely, XLAShardedTensor will satisfy the following requirements:
- XLAShardedTensor, as a torch.Tensor subclass, should work directly with native torch ops and module.layers. We use __torch_dispatch__ to send XLAShardedTensor to the XLA backend, and PyTorch/XLA should be able to retrieve the attached sharding annotations to trace the graph with them and invoke the SPMDPartitioner.
- the handles to the local shards are materialized strictly after the lazy execution.
- the local shards (or replicas) are gathered and materialized to CPU when accessed after lazy execution.
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch.utils._pytree import tree_map


@dataclass
class XLAShard:
  data: torch.Tensor
  rank: int
class XLAShardedTensor(torch.Tensor):
  """
  A wrapper around `torch.Tensor` with sharding annotations
  for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
  for IR tracing and converted to an HLO graph with sharding annotations;
  the XLA SPMDPartitioner takes a pass, propagating sharding and injecting
  collectives into the graph before compilation.
  """
  # XLAShardedTensor behaves like an unpartitioned,
  # combined tensor on the host machine. When the user annotates,
  # this is simply set to the input tensor. When an XLA partitioned
  # output tensor (or a sharding-propagated intermediate tensor) returns
  # as XLAShardedTensor, the backend gathers global data across devices,
  # materializes it and sets `global_tensor` on the host; the actual device
  # data still remains on the individual devices as sharded or replicated.
  # Note: we should drop this reference, and force an all-gather on each access.
  global_tensor: torch.Tensor
  # Shards on the devices are materialized/available after the lazy
  # execution of the SPMDPartitioned HLO graph; otherwise,
  # local_shards is set to `None`. Each XLAShard points to
  # a torch.Tensor (xla::device_data).
  # Note: we can consider returning a callback or even defining
  # sharding at XLAShardedTensor construction after the PJRT migration.
  local_shards: Optional[List[XLAShard]] = None

  __slots__ = ['global_tensor']
  @staticmethod
  def __new__(cls, elem: torch.Tensor, *args, **kwargs):
    r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        elem.size(),
        strides=elem.stride(),
        storage_offset=elem.storage_offset(),
        dtype=elem.dtype,
        layout=elem.layout,
        device=elem.device,
        requires_grad=kwargs.get("requires_grad", False))
    r.global_tensor = elem.detach() if r.requires_grad else elem
    return r

  @property
  def sharding_spec(self):
    return NotImplemented

  @property
  def shards(self):
    return NotImplemented

  def __repr__(self):
    return f"XLAShardedTensor({self.global_tensor})"
  @classmethod
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """
    The dispatcher allows the unwrapped torch.Tensor to be re-dispatched to the
    `xla` backend as an XlaTensor, and the XlaTensor with an associated sharding
    spec to be received and wrapped as an XLAShardedTensor.
    """

    def unwrap(elem):
      return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem

    def wrap(elem):
      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem

    # `no_dispatch` is assumed to be a helper context manager that temporarily
    # disables `__torch_dispatch__`; it is only needed if you use
    # enable_python_mode, and it prevents infinite recursion.
    with no_dispatch():
      # re-dispatch to C++
      rs = tree_map(
          wrap,
          func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
    return rs
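To make the dispatch path concrete, here is a hedged sketch of how the wrapper is expected to behave under the __torch_dispatch__ handler above (a usage illustration, not a prescribed API): the op arguments are unwrapped to plain XLA tensors, the op is re-dispatched to the xla backend, and the result is wrapped again so the sharding annotation can follow the traced graph.
a = XLAShardedTensor(torch.randn(8, 4).to(xm.xla_device()))
b = torch.randn(4, 4).to(xm.xla_device())
c = torch.matmul(a, b)  # `a` is unwrapped, matmul runs on the XLA backend,
                        # and the result is wrapped back into an XLAShardedTensor
assert isinstance(c, XLAShardedTensor)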
mark_sharding API
Users can annotate native PyTorch tensors using the mark_sharding API. It takes a torch.Tensor as input and returns an XLAShardedTensor as output.
from typing import Tuple, Union


def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                  mesh_shape: Tuple[int],
                  partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor:
  """
  Annotates the provided tensor with an XLA partition spec. Internally,
  it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.

  Args:
    t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.

    mesh_shape (Tuple[int]): An int tuple describing the logical topology
      of the device mesh, where each element describes the number of devices in
      the corresponding axis.

    partition_spec (Tuple[int, None]): A tuple of device_mesh dimension indices or `None`.
      This specifies how each input rank is sharded (index into mesh_shape) or replicated (None).

  Examples
  --------
  mesh_shape = (4, 2)
  input = torch.randn(8, 32).to(xm.xla_device())
  # 4-way data parallel
  input = xs.mark_sharding(input, mesh_shape, (0, None))

  linear = nn.Linear(32, 10).to(xm.xla_device())
  # 2-way model parallel
  linear.weight = xs.mark_sharding(linear.weight, mesh_shape, (None, 1))
  output = linear(input)
  # full replication
  output = xs.mark_sharding(output, mesh_shape, (None, None))
  """
  return NotImplemented
Sharding Specification
The mark_sharding API takes mesh_shape and partition_spec as input to annotate a tensor with different sharding specifications, like replicated, tiled or partially tiled:
- mesh_shape (Tuple[int]): An int tuple describing the logical topology of the device mesh, where each element describes the number of devices in the corresponding axis.
- partition_spec (Tuple[int, None]): A tuple of device_mesh dimension indices or None. This specifies how each input rank is sharded (index into mesh_shape) or replicated (None).
partition_spec has the same rank as the input tensor, and each dimension describes how the corresponding input tensor dimension is sharded across the device mesh (logically defined by mesh_shape). For example, an 8x32 input tensor t can be annotated as partially tiled over a 4x2 device mesh as follows:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_sharding as xs
...
t = torch.randn(8, 32).to(xm.xla_device())
mesh_shape = (4, 2)
partition_spec = (1, None)
sharded_t = xs.mark_sharding(t, mesh_shape, partition_spec)
partition_spec = (1, None) means that the first dimension (0-th index) of the input is sharded across the two device columns (mesh_shape[partition_spec[0]] = 2), and the second dimension is replicated as specified by partition_spec[1] = None. Similarly, one can replicate with partition_spec = (None, None) or fully shard with partition_spec = (0, 1) across the devices.
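To make the resulting shard shapes concrete, the following hypothetical helper (purely illustrative, not part of the proposed API) computes the per-device shard shape implied by a mesh_shape/partition_spec pair, assuming each sharded dimension divides evenly:
def shard_shape(tensor_shape, mesh_shape, partition_spec):
  # A tensor dimension mapped to a mesh axis is split by that axis size;
  # `None` means the dimension is replicated and kept whole on every device.
  return tuple(dim if axis is None else dim // mesh_shape[axis]
               for dim, axis in zip(tensor_shape, partition_spec))

shard_shape((8, 32), (4, 2), (1, None))     # (4, 32): 2-way tiled on dim 0
shard_shape((8, 32), (4, 2), (None, None))  # (8, 32): fully replicated
shard_shape((8, 32), (4, 2), (0, 1))        # (2, 16): fully tiled over all 8 devices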
Alternatives
PyTorch currently only supports a manual sharding API and primitives, like the ShardedTensor abstraction RFC. This is great for more advanced users who would implement and run custom sharding strategies. The XLAShardedTensor sharding API focuses on bringing automated, XLA compiler-based sharding to PyTorch users.
Additional context
We also have a separate RFC for a high-level GSPMD API that will wrap an existing PyTorch module and apply sharding specs to select tensors. The high-level API will use XLAShardedTensor and mark_sharding as building blocks and make the sharding annotation experience seamless and easy for the user.
cc @ronghanghu @JackCaoG @miladm @pritamdamania87 @wanchaol @fduwjj @mrshenli
This would be a great feature. Looking forward to it!