
[SPMD] auto-sharding PoC

Open · yeounoh opened this issue 1 year ago · 0 comments

This PR implements a proof-of-concept (PoC) prototype of auto-sharding on XLA:TPU, as described in #6322.

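Auto-sharding is enabled either through the XLA_SPMD_AUTO environment variable or in code:

import torch_xla.runtime as xr
# Enable SPMD mode with auto-sharding turned on
xr.use_spmd(auto=True)

Aside from the auto-sharding feature itself, I also adapted xla::OpSharding::UNKNOWN to mark unannotated tensors in SPMD mode.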
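To illustrate the role UNKNOWN plays, here is a toy Python sketch. It is not torch_xla or XLA code: `ShardingType`, `ToyTensor`, `mark_unannotated`, and `auto_shard` are hypothetical names, the enum mirrors only a few xla::OpSharding type names, and the size threshold in `auto_shard` is a placeholder for whatever cost model the real pass uses. What it shows is the contract: tensors the user annotated keep their sharding, and only tensors marked UNKNOWN are left for auto-sharding to decide.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional


class ShardingType(Enum):
    """Hypothetical mirror of a few xla::OpSharding type names."""
    REPLICATED = "REPLICATED"
    OTHER = "OTHER"      # stands in for an explicit tiled sharding
    UNKNOWN = "UNKNOWN"  # not annotated by the user; left to auto-sharding


@dataclass
class ToyTensor:
    name: str
    numel: int
    sharding: Optional[ShardingType] = None  # None = never annotated


def mark_unannotated(tensors: List[ToyTensor]) -> List[str]:
    """Tag every unannotated tensor as UNKNOWN and return their names."""
    marked = []
    for t in tensors:
        if t.sharding is None:
            t.sharding = ShardingType.UNKNOWN
            marked.append(t.name)
    return marked


def auto_shard(tensors: List[ToyTensor],
               threshold: int = 1_000_000) -> Dict[str, ShardingType]:
    """Resolve UNKNOWN shardings; user-provided shardings are never changed.

    The size threshold is a toy placeholder, not the real pass's cost model.
    """
    decisions = {}
    for t in tensors:
        if t.sharding is ShardingType.UNKNOWN:
            t.sharding = (ShardingType.OTHER if t.numel >= threshold
                          else ShardingType.REPLICATED)
        decisions[t.name] = t.sharding
    return decisions


if __name__ == "__main__":
    tensors = [
        ToyTensor("weight", numel=4096 * 4096),
        ToyTensor("bias", numel=4096),
        ToyTensor("input", numel=8 * 4096, sharding=ShardingType.REPLICATED),
    ]
    print(mark_unannotated(tensors))  # ['weight', 'bias']
    print({k: v.value for k, v in auto_shard(tensors).items()})
    # {'weight': 'OTHER', 'bias': 'REPLICATED', 'input': 'REPLICATED'}
```

Keeping "unannotated" as an explicit marker, rather than treating a missing annotation as replicated, is what lets the pass distinguish "the user chose replication" from "the user left this open".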

Known limitations, which we will address in follow-up PRs:

  • XLA:GPU is not supported
  • TPU pod is not supported

I will post a separate PR to demonstrate DTensor API integration for auto-sharding.

cc @baoleai

yeounoh, Mar 12 '24 00:03