Behaviour of xm.all_gather() in SPMD mode
❓ Questions and Help
I would like to confirm whether my MLIR compiler's handling of xm.all_gather() when running Torch-XLA in SPMD mode is correct.
Say I have the following:
- A tensor `t` with shape `[8192, 784]`
- A 2D named mesh `(batch, model)` of 8 devices in a `[2, 4]` configuration:
Device Mesh:
0 1 2 3
4 5 6 7
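To make the sharding layout concrete, here is a small NumPy simulation of what `mark_sharding(t, mesh, ("batch", None))` implies (a sketch, not torch_xla itself; the array sizes are scaled-down stand-ins for `[8192, 784]`): dim 0 is split across the `batch` axis of size 2, and every device in the same mesh row holds the same dim-0 shard because the `model` axis is replicated.

```python
import numpy as np

# Scaled-down stand-ins: 2x4 device mesh, [8, 4] tensor instead of [8192, 784].
mesh = np.arange(8).reshape(2, 4)        # device mesh: axes (batch, model)
t = np.arange(8 * 4).reshape(8, 4)       # global tensor

# ("batch", None) sharding: dim 0 split across the batch axis (size 2),
# dim 1 replicated across the model axis, so each row of devices shares
# one half of dim 0.
shards = {}
for row in range(mesh.shape[0]):                 # batch index
    for dev in mesh[row]:                        # every device in this row
        shards[dev] = t[row * 4:(row + 1) * 4]   # gets the same dim-0 shard

assert shards[0].shape == (4, 4)                       # half of dim 0 per device
assert np.array_equal(shards[1], shards[3])            # replicated along model axis
assert np.array_equal(np.concatenate([shards[0], shards[4]]), t)  # rows cover t
```

Under this layout, the groups `[[0, 4], [1, 5], [2, 6], [3, 7]]` each pair one device from the first batch row with its counterpart in the second, i.e. each group jointly holds both dim-0 shards.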
Now I do the following steps:
- Move the tensor to the XLA device: `t = t.to(torch_xla.device())`
- Shard dim 0 of `t` across the `batch` mesh axis and replicate dim 1: `xs.mark_sharding(t, mesh, ("batch", None))`
- Perform an all-gather operation across dim 0:
```python
# Pair devices across batch rows
groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
y = xm.all_gather(t, 0, groups=groups, pin_layout=False)
y = y.to("cpu")
```
The final `y` tensor has shape `[16384, 784]`, where `y[:8192] == y[8192:] == t`. Is this the correct behaviour?
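One way to account for the observed shape (a sketch of my reading, not an authoritative statement of the semantics; whether SPMD should work this way is exactly my question): if `xm.all_gather` in SPMD mode operates on the logical global `[8192, 784]` tensor rather than on per-device shards, then gathering along dim 0 across a group of size 2 stacks two copies of the global tensor, which matches the doubled result I see. Simulated with scaled-down NumPy arrays:

```python
import numpy as np

# Stand-in for the global [8192, 784] tensor, scaled down to [8, 4].
t_global = np.arange(8 * 4).reshape(8, 4)

# If every group member presents the full global tensor to all_gather,
# a group of size 2 gathering along dim 0 yields two stacked copies.
y = np.concatenate([t_global, t_global], axis=0)

assert y.shape == (16, 4)                 # analogue of [16384, 784]
assert np.array_equal(y[:8], y[8:])       # analogue of y[:8192] == y[8192:]
assert np.array_equal(y[:8], t_global)    # each half equals t
```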
Thank you for filing this issue. cc @bhavya01