
Mismatched rank in collective ops (all-gather, reduce-scatter, and all-to-all) in PJRT runtime

Open · ronghanghu opened this issue 2 years ago · 0 comments

🐛 Bug

For the PJRT runtime, xm.all_reduce works well on v3-8 now that https://github.com/pytorch/xla/pull/3704 is merged, but xm.reduce_scatter and xm.all_to_all still do not work correctly, and xm.all_gather does not work correctly under pin_layout=False (xm.all_gather under pin_layout=True is actually implemented via all_reduce).

This issue is submitted based on the discussions in https://github.com/pytorch/xla/pull/3813#issuecomment-1204215319.

To Reproduce

The example below contains a test case for each collective op (all_gather, reduce_scatter, and all_to_all), none of which currently works correctly under PJRT.

  1. Allocate a new v3-8 TPU VM with the tpu-vm-pt-1.12 runtime and install the nightly PyTorch/XLA wheels that contain https://github.com/pytorch/xla/pull/3704:
# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220802-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220802-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220802-cp38-cp38-linux_x86_64.whl

# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220518-py3-none-any.whl
  2. Save the following content to a file test_pjrt_collective_ops.py:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def test_pjrt_collective_ops():
    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    pin_layout = False

    xm.mark_step()

    # all-gather -- expected output for rank i on v3-8:
    #   [0, 1, 2, 3, 4, 5, 6, 7]
    t1 = (torch.ones(1) * rank).to(device)
    all_gather_out = xm.all_gather(t1, dim=0, pin_layout=pin_layout)

    # reduce-scatter -- expected output for rank i on v3-8:
    #   [-i]
    t2 = -torch.arange(world_size, dtype=torch.float32).to(device)
    reduce_scatter_out = xm.reduce_scatter(
        xm.REDUCE_SUM,
        t2,
        scale=1.0 / world_size,
        scatter_dim=0,
        shard_count=world_size,
        pin_layout=pin_layout,
    )

    # all-to-all -- expected output for rank i on v3-8:
    #   [[[-i, -i, -i, -i, -i, -i, -i, -i],
    #     [ 0,  1,  2,  3,  4,  5,  6,  7]]]
    t3 = torch.cat(
        [
            -torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1),
            torch.ones(world_size, 1, 1) * rank,
        ],
        dim=1,
    ).to(device)
    all_to_all_out = xm.all_to_all(
        t3,
        split_dimension=0,
        concat_dimension=2,
        split_count=world_size,
        pin_layout=pin_layout,
    )

    xm.mark_step()
    print(
        f"\n[rank {rank} of {world_size} (pin_layout={pin_layout})]\n"
        f"* all_gather_out:\n{all_gather_out}\n"
        f"* reduce_scatter_out:\n{reduce_scatter_out}\n"
        f"* all_to_all_out:\n{all_to_all_out}\n",
        end="",
        flush=True,
    )


if __name__ == "__main__":
    pjrt.run_multiprocess(test_pjrt_collective_ops)
  3. Run this file under PJRT via PJRT_DEVICE=TPU python3 test_pjrt_collective_ops.py. The run shows that these 3 collective ops produce outputs that do not match the rank from xm.get_ordinal(), printing the following:
[rank 0 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([0.], device='xla:0')
* all_to_all_out:
tensor([[[0., -0., -0., -0., -0., -0., -0., -0.],
         [0., 1., 4., 5., 6., 7., 2., 3.]]], device='xla:0')

[rank 1 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-1.], device='xla:1')
* all_to_all_out:
tensor([[[-1., -1., -1., -1., -1., -1., -1., -1.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:1')

[rank 2 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([-6.], device='xla:0')
* all_to_all_out:
tensor([[[-6., -6., -6., -6., -6., -6., -6., -6.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:0')

[rank 3 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-7.], device='xla:1')
* all_to_all_out:
tensor([[[-7., -7., -7., -7., -7., -7., -7., -7.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:1')

[rank 4 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([-2.], device='xla:0')
* all_to_all_out:
tensor([[[-2., -2., -2., -2., -2., -2., -2., -2.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:0')

[rank 5 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-3.], device='xla:1')
* all_to_all_out:
tensor([[[-3., -3., -3., -3., -3., -3., -3., -3.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:1')

[rank 6 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:0')
* reduce_scatter_out:
tensor([-4.], device='xla:0')
* all_to_all_out:
tensor([[[-4., -4., -4., -4., -4., -4., -4., -4.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:0')

[rank 7 of 8 (pin_layout=False)]
* all_gather_out:
tensor([0., 1., 4., 5., 6., 7., 2., 3.], device='xla:1')
* reduce_scatter_out:
tensor([-5.], device='xla:1')
* all_to_all_out:
tensor([[[-5., -5., -5., -5., -5., -5., -5., -5.],
         [ 0.,  1.,  4.,  5.,  6.,  7.,  2.,  3.]]], device='xla:1')

For example, in the output above, reduce_scatter performs the "reduce" part correctly, but the "scatter" part delivers a shard whose index does not match the rank reported by xm.get_ordinal().
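Concretely, the following offline check (plain Python over the numbers pasted above; the variable names are mine) shows that the all_gather and reduce_scatter outputs are consistent with one and the same permutation of device order:

# Order in which each rank's data appears in every rank's all_gather output above.
observed_gather_order = [0, 1, 4, 5, 6, 7, 2, 3]

# reduce_scatter value printed by each rank above (rank -> value).
observed_scatter = {0: 0.0, 1: -1.0, 2: -6.0, 3: -7.0,
                    4: -2.0, 5: -3.0, 6: -4.0, 7: -5.0}

# Inverse permutation: the slot where rank i's contribution landed.
slot_of_rank = {r: slot for slot, r in enumerate(observed_gather_order)}

# With t2 = -arange(8) on every rank, shard j of the reduced tensor is [-j], so a
# rank that receives shard slot_of_rank[i] instead of shard i prints -slot_of_rank[i].
for rank, value in observed_scatter.items():
    assert value == -float(slot_of_rank[rank]), (rank, value)
print("reduce_scatter outputs match the same device-order permutation as all_gather")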

Expected behavior

For each rank i (0 through 7 on v3-8), the expected outputs are:

all-gather -- expected output for rank i on v3-8:
   [0, 1, 2, 3, 4, 5, 6, 7]
reduce-scatter -- expected output for rank i on v3-8:
   [-i]
all-to-all -- expected output for rank i on v3-8:
   [[[-i, -i, -i, -i, -i, -i, -i, -i],
     [ 0,  1,  2,  3,  4,  5,  6,  7]]]
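
Equivalently, a correct run should pass per-rank checks like the following sketch, which one could append at the end of test_pjrt_collective_ops above (check_expected is a hypothetical helper name; the shapes assume the v3-8 setup from the repro):

import torch


def check_expected(rank, world_size, all_gather_out, reduce_scatter_out, all_to_all_out):
    # Expected values from the description above, materialized on CPU for comparison.
    expected_gather = torch.arange(world_size, dtype=torch.float32)
    expected_scatter = torch.tensor([-float(rank)])
    expected_all_to_all = torch.stack([
        torch.full((world_size,), -float(rank)),
        torch.arange(world_size, dtype=torch.float32),
    ]).unsqueeze(0)  # shape (1, 2, world_size)
    assert torch.equal(all_gather_out.cpu(), expected_gather), all_gather_out
    assert torch.equal(reduce_scatter_out.cpu(), expected_scatter), reduce_scatter_out
    assert torch.equal(all_to_all_out.cpu(), expected_all_to_all), all_to_all_out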

Environment

  • Reproducible on XLA backend [CPU/TPU]: TPU (v3-8 TPU VM with the tpu-vm-pt-1.12 runtime)
  • torch_xla version: nightly 20220802 (see above)

Additional context

  • Another problem is that xm.rendezvous is not working under PJRT yet -- it doesn't actually introduce a barrier across all ranks. I think once these issues are resolved (and https://github.com/pytorch/xla/pull/3813 is merged), it should be possible to port most XRT codebases to PJRT.
  • The xm.all_reduce op seems to work well under PJRT, perhaps because it's not affected by the rank order.
  • The examples above use pin_layout=False, but the results should be mostly the same with pin_layout=True. Note that xm.all_gather only performs an actual all-gather under pin_layout=False; under pin_layout=True, xm.all_gather does not perform an all-gather at all, but instead first pads the input with zeros and then performs an all_reduce (see the sketch after this list).
  • Also, it would be great to tuplify the input during compilation and execution to get rid of the size limit in https://github.com/pytorch/xla/issues/3453#issuecomment-1121490148 and allow training larger models in PyTorch/XLA.
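
For reference, below is a rough sketch of the zero-padding-plus-all_reduce strategy mentioned in the bullet above. This is my own illustration of the idea, not the library's actual implementation; it assumes a 1-D input gathered along dim 0, and all_gather_via_all_reduce is a hypothetical helper name.

import torch
import torch_xla.core.xla_model as xm


def all_gather_via_all_reduce(value: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Place this rank's slice into a zero-initialized buffer at its own offset...
    padded = torch.zeros(
        world_size * value.size(0), dtype=value.dtype, device=value.device
    )
    padded[rank * value.size(0):(rank + 1) * value.size(0)] = value
    # ...then sum the buffers across ranks. Every position is non-zero on exactly
    # one rank, so the summed result is the concatenation of all ranks' slices.
    return xm.all_reduce(xm.REDUCE_SUM, padded)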

cc: @will-cromar

ronghanghu · Aug 03 '22 21:08