
Changes needed for `device_assignment` in `PjRtComputationClient::Compile` to support submeshing/localized SPMD

Open • kvshbg-aws opened this issue 6 months ago • 1 comment

Currently, the xla::DeviceAssignment used inside the PjRtComputationClient::Compile call is built from the local devices/device count. For submeshing, or even for localized SPMD, we would instead want to use the global_device_ids that the user provided.

Once the wrapper class defined in #9390 is in place, we can update the sharding object inside the PjRtShardedData struct to use torch_xla::OpSharding instead of xla::OpSharding (here and here). This will allow accessing the original_tile_assignment (with global device IDs) through the torch_xla::OpSharding object, so the PJRT device_assignment can use the device IDs that the user provided during Mesh initialization.

With the current implementation of PjRtComputationClient::Compile, a user cannot specify a subset of the addressable devices in a program, because PjRtComputationClient infers the number of participants from the PJRT local devices. We would therefore also have to extend the function to support sub-meshing: when creating the device_assignments (here and here), it should use the torch_xla::OpSharding sharding.tile_assignment(), and hence the global_device_ids.
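A minimal sketch of the difference, assuming an SPMD-style layout of one replica and one partition per device. `SketchDeviceAssignment`, `FromLocalDevices`, and `FromTileAssignment` are hypothetical stand-ins, not the real `xla::DeviceAssignment` or torch_xla code: the first numbers the partitions from the local device count (the current behaviour described above), while the second takes them from the flattened `original_tile_assignment` (the proposed behaviour).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for xla::DeviceAssignment: a (replicas x partitions)
// grid of device IDs stored row-major. Not the real XLA class.
struct SketchDeviceAssignment {
  int replica_count;
  int computation_count;
  std::vector<int64_t> ids;  // ids[replica * computation_count + computation]

  int64_t operator()(int replica, int computation) const {
    return ids[static_cast<std::size_t>(replica) * computation_count + computation];
  }
};

// Current behaviour (as described in the issue): one partition per local
// device, numbered 0..local_device_count-1.
SketchDeviceAssignment FromLocalDevices(int local_device_count) {
  SketchDeviceAssignment da{1, local_device_count, {}};
  for (int i = 0; i < local_device_count; ++i) da.ids.push_back(i);
  return da;
}

// Proposed behaviour (sketch): one partition per entry of the flattened
// original tile assignment, keeping the user's global device IDs.
SketchDeviceAssignment FromTileAssignment(
    const std::vector<int64_t>& flattened_global_ids) {
  if (flattened_global_ids.empty())
    throw std::invalid_argument("tile assignment must not be empty");
  return SketchDeviceAssignment{
      1, static_cast<int>(flattened_global_ids.size()), flattened_global_ids};
}
```

For a submesh over global devices 4 through 7, `FromTileAssignment({4, 5, 6, 7})` yields four partitions mapped to exactly those devices, whereas `FromLocalDevices(8)` always produces eight participants numbered 0 through 7.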

This will further let us add asserts on the device IDs of the submesh, for example: the submesh must be a subset of client_->addressable_devices, and all sharded tensors within a process must have the same mesh/submesh, i.e. the same tile_assignment.
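The two asserts could look roughly like this. `TileAssignmentSketch`, `IsSubsetOfAddressable`, and `AllShareTileAssignment` are hypothetical; device IDs are plain `int64_t` values rather than PJRT device objects, whereas a real check would read them from client_->addressable_devices and from each tensor's torch_xla::OpSharding.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical stand-in for a sharding's tile assignment: the mesh shape plus
// the global device IDs in row-major order.
struct TileAssignmentSketch {
  std::vector<int64_t> dims;
  std::vector<int64_t> device_ids;

  bool operator==(const TileAssignmentSketch& other) const {
    return dims == other.dims && device_ids == other.device_ids;
  }
};

// True if every device ID in the submesh is one of this process's
// addressable device IDs.
bool IsSubsetOfAddressable(const std::vector<int64_t>& submesh_ids,
                           const std::vector<int64_t>& addressable_ids) {
  std::set<int64_t> addressable(addressable_ids.begin(), addressable_ids.end());
  return std::all_of(submesh_ids.begin(), submesh_ids.end(),
                     [&](int64_t id) { return addressable.count(id) > 0; });
}

// True if all sharded tensors carry the same tile assignment (same shape and
// same device IDs), i.e. the same mesh/submesh. Vacuously true when empty.
bool AllShareTileAssignment(const std::vector<TileAssignmentSketch>& tiles) {
  return std::all_of(tiles.begin(), tiles.end(),
                     [&](const TileAssignmentSketch& t) { return t == tiles.front(); });
}
```

Comparing the shape as well as the IDs matters: a 2x2 mesh and a flat 4-device mesh over the same devices are different tile assignments.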

To support the above, we would have to change the CompileInstance object to include an additional field holding the DataPtr to the pjrt_sharded_data. We can then use it to get the sharding (which will be a torch_xla::OpSharding object) and, from that sharding object, the original_tile_assignment.
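A rough sketch of the extra CompileInstance field, using simplified hypothetical types (`OpShardingSketch`, `DataSketch`, `ShardedDataSketch`, `CompileInstanceSketch`, `GlobalDeviceIds`) in place of the real torch_xla structs. It assumes the data handle is a `std::shared_ptr` to a polymorphic base class and recovers the global device IDs by downcasting to the sharded-data type and reading its `original_tile_assignment`.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for torch_xla::OpSharding.
struct OpShardingSketch {
  std::vector<int64_t> tile_assignment_devices;   // as currently passed to XLA
  std::vector<int64_t> original_tile_assignment;  // global IDs from the Mesh
};

// Hypothetical polymorphic data handle and its sharded subtype.
struct DataSketch { virtual ~DataSketch() = default; };
using DataPtr = std::shared_ptr<DataSketch>;

struct ShardedDataSketch : DataSketch {
  std::optional<OpShardingSketch> sharding;
};

// Hypothetical CompileInstance with the proposed additional field.
struct CompileInstanceSketch {
  std::string computation_name;
  DataPtr sharded_data;  // proposed: handle to the pjrt_sharded_data
};

// Returns the user's global device IDs, or nullopt if no sharded data (or no
// sharding) is attached to the compile instance.
std::optional<std::vector<int64_t>> GlobalDeviceIds(
    const CompileInstanceSketch& instance) {
  auto sharded = std::dynamic_pointer_cast<ShardedDataSketch>(instance.sharded_data);
  if (!sharded || !sharded->sharding) return std::nullopt;
  return sharded->sharding->original_tile_assignment;
}
```

Returning `std::nullopt` when nothing is attached is one option that would let Compile fall back to the existing local-device behaviour for unsharded programs.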

kvshbg-aws, Jun 18 '25 20:06

I think this is a nice expansion of https://github.com/pytorch/xla/issues/9357. Marking it as a duplicate of this one.

pgmoka, Jun 24 '25 18:06