
Create and Expose the `torch_xla::OpSharding` wrapper class instead of `xla::OpSharding` class

kvshbg-aws opened this issue 6 months ago

As part of the localized SPMD/submeshing effort, we need to abstract the xla::OpSharding proto object behind a torch_xla-specific wrapper class and expose a torch_xla::OpSharding object to the user instead of the xla proto object.

Currently we expose the xla::OpSharding object directly to the user when they call get_op_sharding. This abstraction is relatively low level; it works well when operating over all devices, but not when users start their mesh at a device ID other than 0.

To resolve this, we propose creating a new torch_xla::OpSharding class that wraps the xla::OpSharding proto. This wrapper class will have a constructor that takes the xla::OpSharding object and the tile_assignment object. (The tile_assignment object carries the global_device_ids, which are required further down the stack when creating the device_assignment for the pjrt_client, hence we would like to pass it as a constructor parameter.)

The wrapper class needs forwarding methods so that the user can keep using the same APIs as xla::OpSharding (for example, sharding.type()) to access the proto’s fields. The only difference is that the user will now call those APIs on a torch_xla::OpSharding object instead of on the xla proto. The upside of this abstraction is that it lets us define custom fields inside the torch_xla::OpSharding class; for example, we can define a global_tile_assignment field that holds the global_device_ids required down the line for our localized SPMD/sub-meshing use case.
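
Below is a minimal sketch of what such a wrapper could look like. The member names, the include paths, and the choice of std::vector<int64_t> for the global tile assignment are illustrative assumptions, not the final API.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "xla/xla_data.pb.h"  // xla::OpSharding proto (include path assumed)

namespace torch_xla {

// Sketch of the proposed wrapper: it owns the xla::OpSharding proto plus the
// global device IDs from the tile assignment, which are needed later when
// building the device assignment for the PjRt client.
class OpSharding {
 public:
  OpSharding(xla::OpSharding proto,
             std::vector<int64_t> global_tile_assignment)
      : proto_(std::move(proto)),
        global_tile_assignment_(std::move(global_tile_assignment)) {}

  // Forwarded accessor: call sites keep the familiar xla::OpSharding surface,
  // e.g. sharding.type().
  xla::OpSharding::Type type() const { return proto_.type(); }

  // Escape hatch for code that still needs the raw proto.
  const xla::OpSharding& GetXlaOpSharding() const { return proto_; }

  // Wrapper-only field that the plain proto cannot carry.
  const std::vector<int64_t>& global_tile_assignment() const {
    return global_tile_assignment_;
  }

 private:
  xla::OpSharding proto_;
  std::vector<int64_t> global_tile_assignment_;
};

}  // namespace torch_xla
```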

Once we have the wrapper class ready, we will return it instead of xla::OpSharding from CreateOpSharding in xla_sharding_util.cpp. As a result, init_python_bindings.cpp will also need changes, where the references to xla::OpSharding will be converted to torch_xla::OpSharding.
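
A rough sketch of the return-type change and the corresponding binding update is below, assuming the wrapper above lives in a hypothetical torch_xla/csrc/op_sharding.h header. The real CreateOpSharding takes more parameters (pybind11 lists, the sharding type, etc.) and the existing binding code in init_python_bindings.cpp differs in detail.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "torch_xla/csrc/op_sharding.h"  // hypothetical header for the wrapper
#include "xla/xla_data.pb.h"             // xla::OpSharding proto (path assumed)

namespace torch_xla {

// Simplified stand-in for ShardingUtil::CreateOpSharding: build the proto as
// before, then wrap it together with the global device IDs instead of
// returning the raw xla::OpSharding.
OpSharding CreateOpSharding(xla::OpSharding proto,
                            std::vector<int64_t> global_device_ids) {
  return OpSharding(std::move(proto), std::move(global_device_ids));
}

}  // namespace torch_xla

// Hypothetical binding update for init_python_bindings.cpp: expose the wrapper
// type rather than the raw proto.
void InitOpShardingBindings(pybind11::module& m) {
  namespace py = pybind11;
  py::class_<torch_xla::OpSharding>(m, "OpSharding")
      // Return the sharding type as an int to avoid registering the proto enum
      // in this sketch; the real binding may expose it differently.
      .def("type",
           [](const torch_xla::OpSharding& s) {
             return static_cast<int>(s.type());
           })
      .def("global_tile_assignment",
           &torch_xla::OpSharding::global_tile_assignment);
}
```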

Note - changing the return type here will also require changing the type in all downstream functions/call sites that use the xla::OpSharding object.
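
To illustrate that propagation, a downstream helper that currently receives the proto would change its signature to take the wrapper and unwrap it only where the raw proto is still required (the function name and body here are hypothetical):

```cpp
#include "torch_xla/csrc/op_sharding.h"  // hypothetical header for the wrapper

// Before: void AttachShardingToTensor(const xla::OpSharding& sharding);
// After: accept the wrapper and unwrap at the boundary that needs the proto.
void AttachShardingToTensor(const torch_xla::OpSharding& sharding) {
  const xla::OpSharding& proto = sharding.GetXlaOpSharding();
  // ... existing proto-based logic keeps operating on `proto` ...
  (void)proto;  // placeholder so this sketch compiles
}
```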

kvshbg-aws — Jun 18 '25

I initially created https://github.com/pytorch/xla/issues/9334 and https://github.com/pytorch/xla/issues/9356 with the intent to achieve what this bug describes doing with torch_xla::OpSharding. The idea behind those two is to standardize sharding abstractions under our backend.

This bug covers the final set of changes. We can make it the parent bug for those two. @kvshbg-aws and @rpsilva-aws, do you have any objections to that?

pgmoka — Jun 24 '25

Sounds good, feel free to move these. Thanks @pgmoka.

rpsilva-aws — Jun 24 '25