
feat: abstraction of xla::OpSharding proto using wrapper class

Open · kvshbg-aws opened this issue 5 months ago · 1 comment

This PR abstracts the xla::OpSharding proto object into a torch_xla::OpSharding wrapper class.

The new class is not subject to the requirements of xla::OpSharding, although it extends the xla::OpSharding proto defined here. The wrapper class, defined in torch/xla, constructs an xla::OpSharding object, stores additional fields such as global_device_ids/global_tile_assignment, and provides forwarding (proxy) functions to xla::OpSharding. These forwarding functions let users keep calling the same xla::OpSharding APIs they normally would. We can also define torch_xla-specific functions in the wrapper class that use the extra fields stored when the OpSharding object was initialized.

This approach also allows the torch_xla::OpSharding object to be converted back to an xla::OpSharding when lowering to HLO, so the abstracted class (and the additional fields it stores) can be used anywhere in the code base as needed. This is particularly useful because XLA's HLO device ids are 0-indexed: we must use the normalized_device_ids (starting from index 0) when lowering the program into HLO, while still using the denormalized global_device_ids elsewhere, for example inside the PJRT client to set the device_assignment from the user-specified device_ids.
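A minimal Python sketch of the idea described above, under loudly stated assumptions: `XlaOpShardingStub` is a simplified stand-in for the xla::OpSharding proto (not the real protobuf API), `OpShardingWrapper` and `normalize_device_ids` are hypothetical names, and rank-based normalization is just one plausible scheme, not necessarily what this PR implements. The actual wrapper is C++ in torch_xla; the sketch only illustrates the shape of the design: an inner proto holding 0-indexed ids, an extra `global_device_ids` field, attribute forwarding, and conversion back to a plain proto for HLO lowering.

```python
import math
from dataclasses import dataclass, field, replace
from typing import List


@dataclass
class XlaOpShardingStub:
    """Simplified stand-in for the xla::OpSharding proto (hypothetical, not the real API)."""
    type: str = "OTHER"
    tile_assignment_dimensions: List[int] = field(default_factory=list)
    tile_assignment_devices: List[int] = field(default_factory=list)


def normalize_device_ids(global_ids: List[int]) -> List[int]:
    # Assumed normalization: map each global id to its rank among the distinct
    # ids, so the smallest id becomes 0 (e.g. [4, 6, 5, 7] -> [0, 2, 1, 3]).
    rank = {d: i for i, d in enumerate(sorted(set(global_ids)))}
    return [rank[d] for d in global_ids]


class OpShardingWrapper:
    """Hypothetical sketch of a torch_xla::OpSharding-style wrapper class."""

    def __init__(self, sharding_type: str, tile_assignment_dimensions: List[int],
                 global_device_ids: List[int]):
        if math.prod(tile_assignment_dimensions) != len(global_device_ids):
            raise ValueError("need exactly one device id per tile")
        # Extra field kept alongside the proto: the user-specified device ids.
        self.global_device_ids = list(global_device_ids)
        # The wrapped proto holds 0-indexed ids, as HLO lowering requires.
        self._sharding = XlaOpShardingStub(
            type=sharding_type,
            tile_assignment_dimensions=list(tile_assignment_dimensions),
            tile_assignment_devices=normalize_device_ids(global_device_ids),
        )

    def __getattr__(self, name):
        # Only called when normal lookup fails: forward to the wrapped proto so
        # existing xla::OpSharding-style accessors keep working.
        if name == "_sharding":
            raise AttributeError(name)
        return getattr(self._sharding, name)

    def to_xla_op_sharding(self) -> XlaOpShardingStub:
        # Hand an independent copy of the plain proto to the HLO lowering code.
        return replace(
            self._sharding,
            tile_assignment_dimensions=list(self._sharding.tile_assignment_dimensions),
            tile_assignment_devices=list(self._sharding.tile_assignment_devices),
        )


if __name__ == "__main__":
    wrapped = OpShardingWrapper("OTHER", [2, 2], global_device_ids=[4, 6, 5, 7])
    print(wrapped.type)                     # OTHER (forwarded to the proto)
    print(wrapped.tile_assignment_devices)  # [0, 2, 1, 3] (normalized, for HLO)
    print(wrapped.global_device_ids)        # [4, 6, 5, 7] (e.g. for device_assignment)
```

Forwarding through `__getattr__` plays the role of the proxy functions: callers read `wrapped.type` exactly as they would on the proto, while `global_device_ids` stays available for places such as setting a PJRT device_assignment.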

Component diagram for reference: [image]

Ref issue - https://github.com/pytorch/xla/issues/9390

kvshbg-aws · Jul 10 '25 00:07

LGTM pending tests

pgmoka · Aug 26 '25 17:08