
Create Local SPMD

pgmoka opened this issue 7 months ago

EDIT: Rather than creating a new RFC, I have decided to expand this GitHub issue with more information on achieving Local SPMD.

Context

Previous work to achieve this was attempted in https://github.com/pytorch/xla/pull/8810. That first attempt was paused due to a couple of issues. The primary limitation of the initial PR was that Local SPMD was locked to the devices associated with a specific host.

For a better user experience, we want to be able to associate devices with a Local SPMD instance independently of their host. This is also necessary to support a wider range of use cases.

Now that we are working to pursue SPMD+MPMD, we also have an incentive to achieve Local SPMD as a first step towards that goal.

The basics of achieving Local SPMD

To achieve Local SPMD, we need to pass to our compute runtime (currently PJRT) the association between mesh ordinal devices (devices as referenced by the user) and PJRT IDs (i.e., the hardware ID of each device).

This is done in xla::DeviceAssignment and is the point where we need to make our modifications. You can see where and how this was done in the original prototype.

We need to change this such that the ordinals can be set by the user, and do not necessarily need to match the total number of devices available.
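As a very rough sketch, building such an assignment could look something like the following. The function name and the `mesh_ordinal_to_pjrt_id` mapping are made up for illustration; only `xla::DeviceAssignment` is existing XLA code, and the replica/computation layout should be double-checked against the actual runtime code.

```cpp
#include <vector>

#include "xla/service/computation_placer.h"

// Sketch: turn a user-chosen mesh-ordinal -> PJRT ID mapping into an
// xla::DeviceAssignment. Everything except xla::DeviceAssignment is
// hypothetical.
xla::DeviceAssignment BuildLocalSpmdAssignment(
    const std::vector<int>& mesh_ordinal_to_pjrt_id) {
  // Assume one replica and one computation (partition) per participating
  // device, as is typical for SPMD.
  xla::DeviceAssignment assignment(
      /*replica_count=*/1,
      /*computation_count=*/static_cast<int>(mesh_ordinal_to_pjrt_id.size()));
  for (size_t ordinal = 0; ordinal < mesh_ordinal_to_pjrt_id.size(); ++ordinal) {
    // `ordinal` is the device position the user sees in the mesh; the value is
    // the PJRT (hardware) ID it maps to. For Local SPMD this list does not
    // have to start at 0 or cover every device available.
    assignment(0, ordinal) = mesh_ordinal_to_pjrt_id[ordinal];
  }
  return assignment;
}
```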

Refactoring xla::OpSharding

Currently, meshes are defined by the user using xla::OpSharding. This abstraction is relatively low level. It works well when operating with all devices, but it runs into issues with any other configuration. For example, users must start their mesh at device 0.
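To illustrate how low level this is, constructing a tiled sharding directly looks roughly like the sketch below. The proto fields are the real OpSharding fields from xla.proto; the helper itself and the 2x2 mesh are just for illustration.

```cpp
#include "xla/xla_data.pb.h"

// For a 2x2 mesh the proto has to enumerate global device IDs explicitly,
// which is where assumptions like "the mesh starts at 0 and uses all devices"
// creep in.
xla::OpSharding MakeTiledSharding() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::OTHER);   // OTHER == tiled sharding
  sharding.add_tile_assignment_dimensions(2);  // 2x2 mesh shape
  sharding.add_tile_assignment_dimensions(2);
  for (int device_id : {0, 1, 2, 3}) {         // flat list of global device IDs
    sharding.add_tile_assignment_devices(device_id);
  }
  return sharding;
}
```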

To resolve this, we should create a new level of abstraction over xla::OpSharding that allows users to define their own mesh, which can then be translated through this new class into what PJRT expects.

Currently, PyTorch/XLA has two other objects abstracting xla::OpSharding:

  1. torch_xla::ShardingSpec in https://github.com/pytorch/xla/blob/r2.7/torch_xla/csrc/tensor_common.h
  2. ShardingSpec in https://github.com/pytorch/xla/blob/r2.7/torch_xla/csrc/tensor.h#L264-L276

Both of these are very similar. I believe we should be able to merge them, and then turn the merged class into torch_xla::OpSharding. This new abstraction will keep track of the global ordinals and their relation to PJRT IDs for the sharded tensor. It will also keep track of what is necessary to create xla::OpSharding and serve as a constructor for those objects.

This should create a consistent sharding experience across our code, and make future modifications to sharding easier.
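As a very rough sketch, the merged class could look something like this. All names and members are hypothetical and only meant to show the intended responsibilities:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "xla/service/computation_placer.h"
#include "xla/xla_data.pb.h"

namespace torch_xla {

// Hypothetical shape of the merged class; nothing here is existing code. The
// key additions over today's ShardingSpec are the explicit mesh-ordinal ->
// PJRT ID mapping and a single place that knows how to emit both
// xla::OpSharding and xla::DeviceAssignment.
class OpSharding {
 public:
  OpSharding(std::vector<int64_t> mesh_shape,
             std::vector<int> mesh_ordinal_to_pjrt_id)
      : mesh_shape_(std::move(mesh_shape)),
        mesh_ordinal_to_pjrt_id_(std::move(mesh_ordinal_to_pjrt_id)) {}

  // Builds the low-level proto XLA consumes today.
  xla::OpSharding ToXlaOpSharding() const;

  // Builds the device assignment handed to PJRT, using the ordinal -> PJRT ID
  // mapping instead of assuming the mesh covers all devices starting at 0.
  xla::DeviceAssignment ToDeviceAssignment() const;

 private:
  std::vector<int64_t> mesh_shape_;
  std::vector<int> mesh_ordinal_to_pjrt_id_;
};

}  // namespace torch_xla
```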

Exposing the new torch_xla::OpSharding abstraction

We will expose the new torch_xla::OpSharding object at the Python level so we can refer to it from Python. We will then modify get_op_sharding to return the new object.

Object-creation changes should happen in the Mesh class, relaxing its requirements so that not all available devices need to be used in the mesh.
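A hedged sketch of what the Python exposure could look like, assuming the hypothetical torch_xla::OpSharding class sketched above. The binding name and constructor signature are illustrative; pybind11 is what PyTorch/XLA already uses for its bindings in init_python_bindings.cpp.

```cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Sketch: surface the hypothetical torch_xla::OpSharding to Python so that
// Mesh / get_op_sharding can construct it directly.
void BindOpSharding(py::module& m) {
  py::class_<torch_xla::OpSharding>(m, "OpSharding")
      .def(py::init<std::vector<int64_t>, std::vector<int>>(),
           py::arg("mesh_shape"), py::arg("mesh_ordinal_to_pjrt_id"));
  // The Python Mesh / get_op_sharding would then construct this object,
  // passing whichever (possibly partial) device list the user provided.
}
```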

Considering order of execution

I suggest we:

  1. Refactor torch_xla::ShardingSpec and ShardingSpec into torch_xla::OpSharding
  2. Enable torch_xla::OpSharding in get_op_sharding
  3. Add the ability for torch_xla::OpSharding to track mesh ordinal devices and PJRT IDs, and establish the necessary relationships to create the correct xla::DeviceAssignment

pgmoka · May 16 '25