[SPMD] auto-sharding PoC
This PR implements an auto-sharding PoC on XLA:TPU, as described in #6322.
Auto-sharding can be enabled with XLA_SPMD_AUTO or from Python:
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
I also adapted xla::OpSharding::UNKNOWN to mark unannotated tensors in SPMD mode.
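
To illustrate the idea of marking unannotated tensors, here is a toy sketch in plain Python. It does not use torch_xla or the real `xla::OpSharding` proto; `ShardingType`, `ToyTensor`, `mark_shardings`, and `undecided` are hypothetical names made up for this example, and the actual implementation in this PR may look quite different. The assumption is simply that a tensor without a user annotation gets tagged UNKNOWN, so a later auto-sharding pass can tell it apart from tensors the user sharded explicitly.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Optional


class ShardingType(Enum):
    # Toy stand-in for a few sharding types; NOT the real xla::OpSharding enum.
    REPLICATED = auto()
    OTHER = auto()    # tiled according to a user annotation
    UNKNOWN = auto()  # no annotation; left for auto-sharding to decide


@dataclass
class ToyTensor:
    name: str
    # None means the user did not annotate this tensor.
    user_sharding: Optional[ShardingType] = None


def mark_shardings(tensors: List[ToyTensor]) -> Dict[str, ShardingType]:
    """Keep user annotations and tag every unannotated tensor as UNKNOWN."""
    marks = {}
    for t in tensors:
        if t.user_sharding is not None:
            marks[t.name] = t.user_sharding
        else:
            marks[t.name] = ShardingType.UNKNOWN
    return marks


def undecided(marks: Dict[str, ShardingType]) -> List[str]:
    """Names of tensors whose sharding is still open for an auto-sharding pass."""
    return [name for name, s in marks.items() if s is ShardingType.UNKNOWN]


if __name__ == "__main__":
    marks = mark_shardings([
        ToyTensor("weight", ShardingType.OTHER),  # explicitly sharded by the user
        ToyTensor("activation"),                  # unannotated
    ])
    print(undecided(marks))
```

In this sketch only `activation` is reported as undecided, since `weight` carries an explicit annotation that is left untouched.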
Some known limitations that we will address in follow-ups:
- XLA:GPU is not supported
- TPU pods are not supported
I will post a separate PR to demonstrate DTensor API integration for auto-sharding.
cc @baoleai