
Behaviour of xm.all_gather() in SPMD mode

Open • hshahTT opened this issue 4 months ago • 1 comment

❓ Questions and Help

I would like to confirm whether my MLIR compiler's handling of xm.all_gather() when running Torch-XLA in SPMD mode is correct.

Say I have the following:

  • A tensor t with shape [8192, 784]
  • A 2D named mesh (batch, model) of 8 devices in a [2, 4] configuration:
Device Mesh:
0 1 2 3
4 5 6 7
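
For reference, the replica groups used further down can be derived from this mesh. The following is a minimal plain-Python sketch, assuming device ids are laid out row-major exactly as in the diagram above; the names `mesh_shape`, `batch_groups`, and `model_groups` are hypothetical and are not part of the Torch-XLA API.

```python
# Plain Python, no Torch-XLA needed. All names here are illustrative.
mesh_shape = (2, 4)  # (batch, model)

# Row-major layout, as in the diagram: device id = batch_idx * 4 + model_idx.
mesh = [[b * mesh_shape[1] + m for m in range(mesh_shape[1])]
        for b in range(mesh_shape[0])]

# Devices that share a model index but differ in batch index
# (the columns of the mesh). These are the peers along the "batch" axis.
batch_groups = [[mesh[b][m] for b in range(mesh_shape[0])]
                for m in range(mesh_shape[1])]

# Devices that share a batch index (the rows of the mesh).
model_groups = [list(row) for row in mesh]

print(mesh)
print(batch_groups)
print(model_groups)
```

Under that layout assumption, the columns of the mesh are the groups that communicate along the batch axis, and the rows are the groups along the model axis.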

Now I do the following steps:

  1. Move the tensor to the XLA device: t = t.to(torch_xla.device())
  2. Shard dim 0 of t across the batch dimension and replicate dim 1: xs.mark_sharding(t, mesh, ("batch", None))
  3. Perform an all-gather operation across dim 0:
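# xm is torch_xla.core.xla_model; xs (step 2) is torch_xla.distributed.spmd
import torch_xla.core.xla_model as xm

# Pair devices across batch rows
groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
y = xm.all_gather(t, 0, groups=groups, pin_layout=False)
y = y.to("cpu")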

The shape of the final y tensor is [16384, 784] where y[:8192] == y[8192:] == t. Is this the correct behaviour?
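
To make the question concrete, here is a small plain-Python model of two ways the all-gather operand could be interpreted. It is only a sketch under stated assumptions: the tensor is scaled down to [8, 3], devices are laid out row-major so that device `d` has batch index `d // 4`, and `all_gather_dim0` is a hypothetical helper, not a Torch-XLA function. It does not establish which interpretation Torch-XLA intends.

```python
# Scaled-down stand-in for the [8192, 784] tensor.
ROWS, COLS = 8, 3
t = [[r * COLS + c for c in range(COLS)] for r in range(ROWS)]
groups = [[0, 4], [1, 5], [2, 6], [3, 7]]


def all_gather_dim0(per_device, groups):
    """Hypothetical model: concatenate each group's operands along dim 0."""
    out = {}
    for group in groups:
        gathered = [row for dev in group for row in per_device[dev]]
        for dev in group:
            out[dev] = gathered
    return out


# Interpretation A: each device contributes its local shard of dim 0.
# Assumes device d has batch index d // 4 (row-major mesh).
half = ROWS // 2
shards = {d: t[(d // 4) * half:(d // 4 + 1) * half] for d in range(8)}
local_view = all_gather_dim0(shards, groups)

# Interpretation B: each participant contributes the full logical tensor.
full = {d: t for d in range(8)}
global_view = all_gather_dim0(full, groups)

print(len(local_view[0]), len(global_view[0]))
```

In this model, interpretation A reassembles the original tensor on every device, while interpretation B yields twice as many rows with both halves equal to `t`, which matches the [16384, 784] shape described above.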

hshahTT • Jul 29 '25 19:07

Thank you for filing this issue. cc @bhavya01

ysiraichi • Jul 30 '25 17:07