Enable Local SPMD through torch_xla::OpSharding
Once mark_sharding has been refactored to use torch_xla::OpSharding, we can build Local SPMD on top of it: the OpSharding will store the correct global device association and pass it through to the compiler.
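
For context, a minimal sketch of today's global-SPMD flow that this refactor targets. The mesh shape and axis names are illustrative; the API calls (`use_spmd`, `Mesh`, `mark_sharding`) are the existing torch_xla SPMD entry points:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Global SPMD today: the mesh must cover every device in the job.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(8, 128).to(torch_xla.device())
# mark_sharding builds the sharding annotation that, per this plan, would be
# represented by torch_xla::OpSharding carrying the device assignment.
xs.mark_sharding(t, mesh, ('data', 'model'))
```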
With the changes from https://github.com/pytorch/xla/issues/9183 in place, the remaining work should mostly consist of relaxing the restrictions on the number of devices that can be passed when executing Local SPMD, as in the sketch below.
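
A hypothetical sketch of what Local SPMD could look like once that restriction is relaxed: a mesh built over only this host's subset of the global devices. The local device-ID computation below is an assumption for illustration, not a shipped API; the key point is that `mark_sharding` would accept a mesh that does not span every global device.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Assumption: a fixed per-host device count; the real mechanism for
# enumerating host-local device IDs is part of the design, not decided here.
devices_per_host = 4
local_ids = np.arange(devices_per_host) + xr.process_index() * devices_per_host

# Today a mesh that does not span all global devices is rejected; the relaxed
# check would accept this host-local mesh and record its device assignment
# in torch_xla::OpSharding for the compiler.
local_mesh = xs.Mesh(local_ids, (devices_per_host, 1), ('data', 'model'))

t = torch.randn(8, 128).to(torch_xla.device())
xs.mark_sharding(t, local_mesh, ('data', 'model'))
```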