xla
Add sharding devices to XlaCompileOptions and plumb them through from JAX.
This change is needed to support MPMD parallelism in McJAX. When the PJRT executable has no addressable devices, the PjRt-IFRT executable can no longer build its output shardings from the PJRT executable's addressable devices, so the sharding devices must be supplied explicitly through XlaCompileOptions.