Mismatched rank in collective ops (all-gather, reduce-scatter, and all-to-all) in PJRT runtime
🐛 Bug
For the PJRT runtime, xm.all_reduce currently works well on v3-8 after https://github.com/pytorch/xla/pull/3704 was merged, but xm.reduce_scatter and xm.all_to_all still do not work well, and xm.all_gather under pin_layout=False doesn't work well either (all_gather under pin_layout=True actually falls back to all_reduce; see the sketch below).
This issue is submitted based on the discussions in https://github.com/pytorch/xla/pull/3813#issuecomment-1204215319.
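As a rough illustration of that pin_layout=True fallback, here is a plain-PyTorch sketch (not the actual torch_xla implementation; it runs on CPU with no collectives involved): an all-gather can be emulated by zero-padding each rank's input into a full-size buffer at its own offset and then summing the buffers, which is the role all_reduce plays in that path.
import torch

def allgather_via_allreduce(inputs):
    # inputs: one 1-D tensor per rank, all with the same number of elements.
    world_size = len(inputs)
    padded = []
    for rank, t in enumerate(inputs):
        buf = torch.zeros(world_size * t.numel())
        buf[rank * t.numel():(rank + 1) * t.numel()] = t
        padded.append(buf)
    # Summing the zero-padded buffers plays the role of the all_reduce step.
    return torch.stack(padded).sum(dim=0)

print(allgather_via_allreduce([torch.tensor([float(i)]) for i in range(8)]))
# tensor([0., 1., 2., 3., 4., 5., 6., 7.])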
To Reproduce
The example below contains a test case for each of the collective ops (all_gather, reduce_scatter, and all_to_all) that currently don't fully work under PJRT.
- Allocate a new v3-8 TPU VM with the tpu-vm-pt-1.12 runtime and install the nightly PyTorch/XLA wheels that contain https://github.com/pytorch/xla/pull/3704:
# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220802-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220802-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220802-cp38-cp38-linux_x86_64.whl
# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220518-py3-none-any.whl
- Save the following content to a file test_pjrt_collective_ops.py:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def test_pjrt_collective_ops():
    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    pin_layout = False
    xm.mark_step()

    # all-gather -- expected output for rank i on v3-8:
    # [0, 1, 2, 3, 4, 5, 6, 7]
    t1 = (torch.ones(1) * rank).to(device)
    all_gather_out = xm.all_gather(t1, dim=0, pin_layout=pin_layout)

    # reduce-scatter -- expected output for rank i on v3-8:
    # [-i]
    t2 = -torch.arange(world_size, dtype=torch.float32).to(device)
    reduce_scatter_out = xm.reduce_scatter(
        xm.REDUCE_SUM,
        t2,
        scale=1.0 / world_size,
        scatter_dim=0,
        shard_count=world_size,
        pin_layout=pin_layout,
    )

    # all-to-all -- expected output for rank i on v3-8:
    # [[[-i, -i, -i, -i, -i, -i, -i, -i],
    #   [0, 1, 2, 3, 4, 5, 6, 7]]]
    t3 = torch.cat(
        [
            -torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1),
            torch.ones(world_size, 1, 1) * rank,
        ],
        dim=1,
    ).to(device)
    all_to_all_out = xm.all_to_all(
        t3,
        split_dimension=0,
        concat_dimension=2,
        split_count=world_size,
        pin_layout=pin_layout,
    )
    xm.mark_step()

    print(
        f"\n[rank {rank} of {world_size} (pin_layout={pin_layout})]\n"
        f"* all_gather_out:\n{all_gather_out}\n"
        f"* reduce_scatter_out:\n{reduce_scatter_out}\n"
        f"* all_to_all_out:\n{all_to_all_out}\n",
        end="",
        flush=True,
    )


if __name__ == "__main__":
    pjrt.run_multiprocess(test_pjrt_collective_ops)
- Run this file under PJRT via PJRT_DEVICE=TPU python3 test_pjrt_collective_ops.py, which shows that these 3 collective ops produce problematic outputs that do not match the rank from xm.get_ordinal(), printing the following:
[rank 0 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([0.], device='xla:0')
* all_to_all_out:
tensor([[[0., -0., -0., -0., -0., -0., -0., -0.],
[0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:0')
[rank 1 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-1.], device='xla:1')
* all_to_all_out:
tensor([[[-1., -1., -1., -1., -1., -1., -1., -1.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:1')
[rank 2 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([-6.], device='xla:0')
* all_to_all_out:
tensor([[[-6., -6., -6., -6., -6., -6., -6., -6.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:0')
[rank 3 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-7.], device='xla:1')
* all_to_all_out:
tensor([[[-7., -7., -7., -7., -7., -7., -7., -7.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:1')
[rank 4 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([-2.], device='xla:0')
* all_to_all_out:
tensor([[[-2., -2., -2., -2., -2., -2., -2., -2.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:0')
[rank 5 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-3.], device='xla:1')
* all_to_all_out:
tensor([[[-3., -3., -3., -3., -3., -3., -3., -3.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:1')
[rank 6 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([-4.], device='xla:0')
* all_to_all_out:
tensor([[[-4., -4., -4., -4., -4., -4., -4., -4.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:0')
[rank 7 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-5.], device='xla:1')
* all_to_all_out:
tensor([[[-5., -5., -5., -5., -5., -5., -5., -5.],
[ 0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:1')
For example, in the output above, reduce_scatter seems to do the "reduce" part correctly, but the "scatter" part delivers each shard to a rank that does not match the rank from xm.get_ordinal().
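To make the mismatch concrete, here is a small self-contained sketch (the values are copied from the log above, not computed on a TPU) showing that the all_gather and reduce_scatter results are consistent with one and the same ordinal permutation:
# Concatenation order observed in every all_gather_out above.
all_gather_order = [0, 1, 4, 5, 6, 7, 2, 3]
# Ordinal -> |value| of its reduce_scatter_out shard, read off the log above.
reduce_scatter_shard = {0: 0, 1: 1, 2: 6, 3: 7, 4: 2, 5: 3, 6: 4, 7: 5}

for ordinal, shard in reduce_scatter_shard.items():
    # Each ordinal received the shard corresponding to its *position* in the
    # all_gather order, i.e. the two permutations are inverses of each other.
    assert all_gather_order.index(ordinal) == shard
print("reduce_scatter shards follow the same permutation as all_gather")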
Expected behavior
For each rank i (in 0, ..., 7 on v3-8), the output should be:
all-gather -- expected output for rank i on v3-8:
[0, 1, 2, 3, 4, 5, 6, 7]
reduce-scatter -- expected output for rank i on v3-8:
[-i]
all-to-all -- expected output for rank i on v3-8:
[[[-i, -i, -i, -i, -i, -i, -i, -i],
  [0, 1, 2, 3, 4, 5, 6, 7]]]
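For reference, a plain-PyTorch sketch (CPU only, no TPU needed; expected_outputs is just an illustrative helper name) that builds these expected tensors, so the values printed by the repro script could be asserted against programmatically:
import torch

def expected_outputs(rank, world_size=8):
    # all_gather: every rank should see [0, 1, ..., world_size - 1].
    all_gather = torch.arange(world_size, dtype=torch.float32)
    # reduce_scatter: rank i should keep the single element [-i].
    reduce_scatter = torch.tensor([-float(rank)])
    # all_to_all: rank i should see [[[-i] * world_size, [0, ..., world_size - 1]]].
    all_to_all = torch.stack(
        [
            torch.full((world_size,), -float(rank)),
            torch.arange(world_size, dtype=torch.float32),
        ]
    ).unsqueeze(0)
    return all_gather, reduce_scatter, all_to_all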
Environment
- Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM with the tpu-vm-pt-1.12 runtime
- torch_xla version: nightly 20220802 (see above)
Additional context
- Another problem is that xm.rendezvous is not working under PJRT yet -- it doesn't actually introduce a barrier across all ranks (see the sketch after this list). I think once these issues are resolved (and https://github.com/pytorch/xla/pull/3813 is merged), it should be possible to port most XRT codebases to PJRT.
- The xm.all_reduce op seems to work well under PJRT, perhaps because it's not affected by the rank order.
- The examples above use pin_layout=False, but the results should be mostly the same if one switches to pin_layout=True (note that xm.all_gather only performs an actual all-gather under pin_layout=False; under pin_layout=True, xm.all_gather doesn't perform an all-gather at all, but instead first pads the inputs with zeros and then performs an all_reduce).
- Also, it would be great to tuplify the input during compilation and execution to get rid of the size limit in https://github.com/pytorch/xla/issues/3453#issuecomment-1121490148 and allow training larger models in PyTorch/XLA.
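A minimal sketch of the rendezvous check mentioned in the first bullet (it assumes the same nightly wheels as above; check_rendezvous is just an illustrative name): with a working barrier, rank r should only leave the rendezvous once the slowest rank has arrived, i.e. after roughly 7 - r more seconds.
import time

import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def check_rendezvous():
    rank = xm.get_ordinal()
    # Stagger arrival times; with a real barrier, every rank should wait for
    # the slowest rank (rank 7) instead of returning immediately.
    time.sleep(rank)
    start = time.time()
    xm.rendezvous("barrier_test")
    print(f"rank {rank} left the rendezvous after {time.time() - start:.1f}s", flush=True)


if __name__ == "__main__":
    pjrt.run_multiprocess(check_rendezvous)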
cc: @will-cromar