Normalize `tile_assignment` after constructing the `xla::OpSharding` object
Currently, users cannot define a mesh whose device_ids start from any value other than 0. This prevents us from defining sub-meshes, and also prevents users from running localized SPMD within a single node. Attempting it fails with the following error:
RuntimeError: Passing an empty index list to Tensor::index() is not valid syntax
To resolve this issue, we propose normalizing the tile_assignment after the xla::OpSharding object is created inside the CreateOpSharding function. This also ensures that the same normalized tile_assignment is carried through the entire path, from XlaMarkSharding() (in xla_sharding_util.cpp) through SetSharding() (in ir.cpp) to the generated HLO.
This will allow users to define sub-meshes whose device_ids start from values other than 0, and will pave the way for inter-node localized SPMD.
We propose using an anonymous function (a lambda) inside xla_sharding_util.cpp to normalize the tile_assignment_devices field of the OpSharding object.
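A minimal C++ sketch of such a lambda follows. It assumes that "normalizing" means remapping each device id to its rank among the ids present in the assignment, so a sub-mesh built on devices 4 to 7 becomes 0 to 3 while the relative layout is preserved; for a contiguous range this is equivalent to subtracting the minimum id. The name `normalize_tile_assignment` is hypothetical, and a `std::vector<int64_t>` stands in for the `tile_assignment_devices` repeated field of `xla::OpSharding`, which the real code would read and rewrite in place.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical lambda: std::vector<int64_t> stands in for the
// tile_assignment_devices field of xla::OpSharding.
// Each device id is replaced by its rank among the distinct ids present,
// so the result always lies in [0, N) and the relative order is preserved.
const auto normalize_tile_assignment =
    [](const std::vector<int64_t>& devices) -> std::vector<int64_t> {
  // Sorted, de-duplicated copy of the device ids.
  std::vector<int64_t> sorted_ids(devices);
  std::sort(sorted_ids.begin(), sorted_ids.end());
  sorted_ids.erase(std::unique(sorted_ids.begin(), sorted_ids.end()),
                   sorted_ids.end());

  // Map every original id to its position in the sorted list.
  std::vector<int64_t> normalized;
  normalized.reserve(devices.size());
  for (int64_t id : devices) {
    auto it = std::lower_bound(sorted_ids.begin(), sorted_ids.end(), id);
    normalized.push_back(static_cast<int64_t>(it - sorted_ids.begin()));
  }
  return normalized;
};
```

For example, `normalize_tile_assignment({6, 7, 4, 5})` yields `{2, 3, 0, 1}`. Ranking rather than subtracting the minimum keeps the result within `0..N-1` even if the chosen device ids are not contiguous.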