
Normalize `tile_assignment` after constructing the `xla::OpSharding` object

kvshbg-aws opened this issue.

Currently, users cannot define a mesh whose device_ids start from anything other than 0. This prevents us from defining sub-meshes, and also prevents users from running localized SPMD within a single node. The resulting error is:

RuntimeError: Passing an empty index list to Tensor::index() is not valid syntax

To resolve this issue, we propose normalizing the tile_assignment after creating the xla::OpSharding inside the CreateOpSharding function. This also ensures that the normalized tile_assignment is carried through the entire pipeline, from XlaMarkSharding() (in xla_sharding_util.cpp) through SetSharding() (in ir.cpp) to the generated HLO. This will allow users to define sub-meshes whose device_ids start from values other than 0, and will also pave the way for inter-node localized SPMD.

We propose using an anonymous function in xla_sharding_util.cpp to normalize the tile_assignment_devices field of the OpSharding object.
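A minimal sketch of what that normalization could look like, written as standalone C++ so it compiles without the torch_xla sources. It assumes "normalize" means remapping each device id to its 0-based rank among the ids in the assignment. The helper name `NormalizeTileAssignmentDevices` is hypothetical, and a `std::vector<int64_t>` stands in for the OpSharding's `tile_assignment_devices` field. The proposal mentions an anonymous function; here it is a file-local helper in an unnamed namespace, and a lambda inside CreateOpSharding would work the same way.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

namespace {

// Hypothetical helper (illustrative name, not from the torch_xla codebase).
// Remaps each device id in a flattened tile assignment to its 0-based rank
// among the ids present, so the smallest id becomes 0, the next 1, and so on.
// The relative order of the tiles is preserved.
std::vector<int64_t> NormalizeTileAssignmentDevices(
    const std::vector<int64_t>& devices) {
  std::vector<int64_t> sorted = devices;
  std::sort(sorted.begin(), sorted.end());
  // Each device is expected to appear only once in a tile assignment.
  assert(std::adjacent_find(sorted.begin(), sorted.end()) == sorted.end() &&
         "device ids must be unique");

  std::vector<int64_t> normalized;
  normalized.reserve(devices.size());
  for (int64_t id : devices) {
    // The position of `id` in the sorted list is its 0-based rank.
    auto it = std::lower_bound(sorted.begin(), sorted.end(), id);
    normalized.push_back(static_cast<int64_t>(it - sorted.begin()));
  }
  return normalized;
}

}  // namespace
```

For example, a sub-mesh on devices {6, 4, 7, 5} would be normalized to {2, 0, 3, 1}. Ranking, rather than simply subtracting the smallest id, also maps a non-contiguous sub-mesh such as {0, 2, 4, 6} to {0, 1, 2, 3}, which subtraction would leave unchanged; whether that is the desired behavior depends on how the rest of the pipeline maps ranks back to physical devices.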

kvshbg-aws, Jun 18 '25 20:06