
Abstract xla::OpSharding

Open pgmoka opened this issue 7 months ago • 3 comments

Currently, our sharding mechanisms work by utilizing xla::OpSharding (first defined here), which we pass directly to the XLA compiler. The sharding is propagated there until the compiler creates an xla::DeviceAssignment based on the mesh device IDs (PjRtDevice::id()). The device IDs referenced in xla::OpSharding must start from 0. This works great for global SPMD.

For Local SPMD, each mesh's IDs do not necessarily begin at 0. Consider:

  • Global SPMD: mesh0{0, 1, 2, 3, 4, 5, 6, 7}
  • Local SPMD: mesh0{0, 1, 2, 3}; mesh1{4, 5, 6, 7}

This creates a situation where we need to normalize the logical mesh to start at 0 before passing it to the compiler. This is not ideal: it obscures that work and creates possible edge cases.
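Below is a minimal sketch of the normalization step described above, assuming a hypothetical helper (NormalizeLocalMeshIds is not an existing function); it remaps a local mesh such as {4, 5, 6, 7} to {0, 1, 2, 3} so the IDs satisfy the start-from-0 expectation:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the normalization we currently have to do:
// shift the local mesh's global device IDs so the logical mesh starts at 0,
// which is what the compiler-facing sharding expects.
std::vector<int64_t> NormalizeLocalMeshIds(const std::vector<int64_t>& global_ids) {
  if (global_ids.empty()) return {};
  int64_t base = *std::min_element(global_ids.begin(), global_ids.end());
  std::vector<int64_t> normalized;
  normalized.reserve(global_ids.size());
  for (int64_t id : global_ids) {
    normalized.push_back(id - base);
  }
  return normalized;
}

// NormalizeLocalMeshIds({4, 5, 6, 7}) -> {0, 1, 2, 3}
```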

To achieve this, we propose a new sharding abstraction: torch_xla::OpSharding (a rough sketch follows the list below). The new object will not carry the dependencies of xla::OpSharding, and will contain:

  • The global IDs associated with the devices from the local mesh
  • PjRtDevice::id()
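
A rough sketch of what such an object could look like (names and layout are illustrative only, not a final API; the include path for the xla::OpSharding proto is an assumption):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "xla/xla_data.pb.h"  // defines xla::OpSharding; path is an assumption

namespace torch_xla {

// Hypothetical wrapper: carries the compiler-facing sharding together with the
// global device IDs of the local mesh, so callers no longer normalize IDs.
class OpSharding {
 public:
  OpSharding(xla::OpSharding sharding, std::vector<int64_t> global_device_ids)
      : sharding_(std::move(sharding)),
        global_device_ids_(std::move(global_device_ids)) {}

  // Underlying proto that is eventually handed to the XLA compiler.
  const xla::OpSharding& xla_sharding() const { return sharding_; }

  // Global IDs (PjRtDevice::id()) of the devices in the local mesh.
  const std::vector<int64_t>& global_device_ids() const {
    return global_device_ids_;
  }

 private:
  xla::OpSharding sharding_;
  std::vector<int64_t> global_device_ids_;
};

}  // namespace torch_xla
```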

This abstraction will later make it easier to interact with mark_sharding, and will improve the quality of our existing code. From a code-structure perspective, very little should change: we will simply start using the new torch_xla::OpSharding object.

We will also want to plug in locally_addressable_devices so that it returns the devices from the local SPMD mesh.
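As a sketch of that plumbing (a hypothetical free function, not the real computation-client method), the idea is to return only the devices belonging to the local SPMD mesh instead of the full global device list:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sketch: given all devices ordered by global ID, return the
// slice that makes up this process's local SPMD mesh.
std::vector<std::string> LocallyAddressableDevices(
    const std::vector<std::string>& all_devices, int64_t local_mesh_begin,
    int64_t local_mesh_size) {
  return std::vector<std::string>(
      all_devices.begin() + local_mesh_begin,
      all_devices.begin() + local_mesh_begin + local_mesh_size);
}
```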

pgmoka · May 16 '25 14:05

Is there a possibility to not make API changes but instead allow setting a flag for when we want PJRT to know about less than the global view?

On a single node, multi-GPU, TF used to have a CUDA_VISIBLE_DEVICES for this purpose and I used it successfully. Within python/cpp, the process only knew about DEVICE 0 and 1 if I set CUDA_VISIBLE_DEVICES=6,7.

yaoshiang · May 16 '25 15:05

> Is there a possibility to not make API changes but instead allow setting a flag for when we want PJRT to know about less than the global view?
>
> On a single node, multi-GPU, TF used to have a CUDA_VISIBLE_DEVICES for this purpose and I used it successfully. Within python/cpp, the process only knew about DEVICE 0 and 1 if I set CUDA_VISIBLE_DEVICES=6,7.

Sorry for the long delay in responding, @yaoshiang. I am not fully familiar with the CUDA approach here. I think it could work for enabling Local SPMD alone, but we also have the long-term goal of enabling SPMD+MPMD.

While a CUDA_VISIBLE_DEVICES-style flag could work well for enabling just Local SPMD, I don't think it will work well with SPMD+MPMD. By making this abstraction, we set ourselves up for SPMD+MPMD.

pgmoka · May 28 '25 17:05

I have been investigating what the specific implementation for this issue might look like. I scoured our code for where xla::OpSharding is being used, and here are my findings:

There is a pre-existing ShardingSpec object that appears to have been created to abstract xla::OpSharding in tensor_common.h and tensor.h. I think it might make sense to expand it into the xla::OpSharding abstraction, to keep things consistent, rather than have multiple abstractions.

Some locations I can guarantee we will be adding our abstraction:

  1. The computation clients:
    • pjrt_computation_client.h
    • pjrt_computation_client.cpp
    • ifrt_computation_client.h
    • ifrt_computation_client.cpp
    • computation_client.h
  2. init_python_bindings.cpp

I am still looking into whether we need changes related to the lazy tensor, which would impact:

  1. ir.h/cpp
  2. lowering_context.cpp
  3. device_data.cpp

Right now we have a util file for sharding operations that works with both the abstract XLATensor::ShardingSpec and xla::OpSharding. I think we should consider making XLATensor::ShardingSpec a class that hides some of these xla::OpSharding operations where appropriate. This should help clean up our code and make sharding handling more consistent.
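A rough sketch, purely illustrative, of how XLATensor::ShardingSpec could hide xla::OpSharding behind methods (method names and the proto include path are assumptions):

```cpp
#include <cstdint>
#include <utility>

#include "xla/xla_data.pb.h"  // defines xla::OpSharding; path is an assumption

namespace torch_xla {

// Hypothetical class-ified ShardingSpec: the raw proto stays private and
// common queries that today live in the sharding util file become methods.
class ShardingSpec {
 public:
  explicit ShardingSpec(xla::OpSharding sharding)
      : sharding_(std::move(sharding)) {}

  // Example queries that currently operate on the raw proto.
  bool IsReplicated() const {
    return sharding_.type() == xla::OpSharding::REPLICATED;
  }
  int64_t NumShards() const {
    return sharding_.tile_assignment_devices_size();
  }

  const xla::OpSharding& xla_sharding() const { return sharding_; }

 private:
  xla::OpSharding sharding_;
};

}  // namespace torch_xla
```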

Ideally this change also serves to eliminate tech debt around device IDs, and clarifies what local/global device IDs are, where they are used, and how they relate to the PJRT ID. To do this right, we need to look at the other places in our code that work with those IDs so behavior stays consistent. I am currently taking a look at these IDs, and will report back here with anything I find.

I think it will be possible to separate what in this refactor is required to enable Local SPMD from what is needed for the longer term, but before we start making changes, I would like to take a look at:

  • The state of IDs in our code
  • Requirements surrounding the lazy tensor, and whether or not we should abstract its xla::OpSharding calls

pgmoka · May 29 '25 18:05