TPU Pod support with PjRt

Open will-cromar opened this issue 2 years ago • 10 comments

  • Move TPU-specific logic from pjrt.py to tpu.py. All of this logic is broadly applicable to all TPU VMs and isn't strictly related to PjRt.
  • Update configure_topology to support multiple hosts.
  • Create unit tests for tpu.py (test_experimental_tpu.py). Everything that requires a TPU is mocked out, so this can run on CPU.
  • Create short integration test for TPU (test_experimental_pjrt_tpu.py). This one initializes the TPU runtime in each test, so this must run on a TPU. Added configs for v3-8 and v4-8. Tested manually on both.

I know the naming of these two new tests is confusing. Let me know if you have alternate suggestions.

PJRT doesn't make a distinction between processes on different hosts, and we already support multiple processes on one host, so no low-level changes were necessary. This PR mainly deals with automatically configuring the TPU topology variables based on the environment.

will-cromar avatar Aug 01 '22 17:08 will-cromar

Simple test case to sanity-check that collectives work as expected:

$ gcloud compute tpus tpu-vm ssh --project=tpu-pytorch --zone=us-central2-b wcromar-v4-32 --internal-ip --worker=all --command 'PJRT_DEVICE=TPU python3 -c "
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt as pjrt
import torch
def f():
  ix = torch.ones([5], device=xm.xla_device()) * xm.get_ordinal()
  return xm.all_reduce(xm.REDUCE_SUM, ix).cpu().numpy()

print(pjrt.run_multiprocess(f))
"'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
{3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}

On a v4-32, there are 16 total workers numbered [0, 16), so the expected result is sum(range(16)) = 120.
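
As a quick cross-check of that expected value, here is a minimal plain-Python sketch of what the all-reduce should produce. It needs no TPU or torch_xla; the helper name `expected_all_reduce_sum` is only illustrative and is not part of any library.

```python
def expected_all_reduce_sum(world_size, length=5):
    """Element-wise sum of ones(length) * ordinal over every ordinal."""
    # One vector per worker, filled with that worker's ordinal.
    per_worker = [[float(ordinal)] * length for ordinal in range(world_size)]
    # REDUCE_SUM adds the vectors position by position.
    return [sum(column) for column in zip(*per_worker)]


print(expected_all_reduce_sum(16))  # [120.0, 120.0, 120.0, 120.0, 120.0]
```

Every worker should see this same vector after the all-reduce, which matches the output above.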

Also confirmed that our ResNet50 example works on v4-16 and v4-32 and gets about the same per-chip performance as on v4-8 with fake data.

will-cromar avatar Aug 02 '22 17:08 will-cromar

Really looking forward to this PR!

Simple test case to sanity-check that collectives work as expected:

It would be great to also resolve the reduce_scatter, all_gather, and all_to_all collective ops in PJRT 😃

Currently all_reduce works well on v3-8 after https://github.com/pytorch/xla/pull/3704, but reduce_scatter still gives unexpected results, and so does all_gather under "pin_layout=False" (all_gather under "pin_layout=True" actually uses all_reduce).


For example, reduce_scatter seems to do the "reduce" part correctly, but the "scatter" part hands out shards in an order that does not match the rank from "xm.get_ordinal", as shown in the case below when running on a v3-8 TPU VM:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def test_pjrt_collective_ops():
  rank = xm.get_ordinal()
  world_size = xm.xrt_world_size()
  device = xm.xla_device()

  t = torch.arange(16, dtype=torch.float32).view(8, 2)
  t = t.to(device)
  xm.mark_step()

  pin_layout = False
  reduce_scatter_out = xm.reduce_scatter(
    xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0, shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}: {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  pin_layout = True
  reduce_scatter_out = xm.reduce_scatter(
    xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0, shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}: {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  return 0.


if __name__ == '__main__':
  pjrt.run_multiprocess(test_pjrt_collective_ops)

which prints


rank 0 of 8:
reduce_scatter_out (pin_layout=False: tensor([[0., 1.]], device='xla:0')
rank 1 of 8:
reduce_scatter_out (pin_layout=False: tensor([[2., 3.]], device='xla:1')
rank 2 of 8:
reduce_scatter_out (pin_layout=False: tensor([[12., 13.]], device='xla:0')
rank 3 of 8:
reduce_scatter_out (pin_layout=False: tensor([[14., 15.]], device='xla:1')
rank 4 of 8:
reduce_scatter_out (pin_layout=False: tensor([[4., 5.]], device='xla:0')
rank 5 of 8:
reduce_scatter_out (pin_layout=False: tensor([[6., 7.]], device='xla:1')
rank 6 of 8:
reduce_scatter_out (pin_layout=False: tensor([[8., 9.]], device='xla:0')
rank 7 of 8:
reduce_scatter_out (pin_layout=False: tensor([[10., 11.]], device='xla:1')

rank 0 of 8:
reduce_scatter_out (pin_layout=True: tensor([[0., 1.]], device='xla:0')
rank 1 of 8:
reduce_scatter_out (pin_layout=True: tensor([[2., 3.]], device='xla:1')
rank 2 of 8:
reduce_scatter_out (pin_layout=True: tensor([[12., 13.]], device='xla:0')
rank 3 of 8:
reduce_scatter_out (pin_layout=True: tensor([[14., 15.]], device='xla:1')
rank 4 of 8:
reduce_scatter_out (pin_layout=True: tensor([[4., 5.]], device='xla:0')
rank 5 of 8:
reduce_scatter_out (pin_layout=True: tensor([[6., 7.]], device='xla:1')
rank 6 of 8:
reduce_scatter_out (pin_layout=True: tensor([[8., 9.]], device='xla:0')
rank 7 of 8:
reduce_scatter_out (pin_layout=True: tensor([[10., 11.]], device='xla:1')

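One can see that the reduce-scatter outputs do not follow rank order. Rank r should receive row r of the input, i.e. [[2r, 2r+1]], but rank 2, for example, receives [[12., 13.]] (the shard expected on rank 6) instead of [[4., 5.]]. all_gather under pin_layout=False has the same problem.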
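
To spell out which shard ended up where, here is a small plain-Python sketch that compares the log above against the expected result. It assumes the reduced tensor equals the input (every rank contributes the same tensor and scale is 1 / world_size), so rank r should hold row r. The names `expected_shard` and `shard_owner` are made up for this illustration.

```python
WORLD_SIZE = 8
COLS = 2


def expected_shard(rank):
    """Row `rank` of arange(WORLD_SIZE * COLS).view(WORLD_SIZE, COLS)."""
    start = rank * COLS
    return [float(v) for v in range(start, start + COLS)]


# Shards copied from the log above (identical for both pin_layout settings).
observed = {
    0: [0., 1.], 1: [2., 3.], 2: [12., 13.], 3: [14., 15.],
    4: [4., 5.], 5: [6., 7.], 6: [8., 9.], 7: [10., 11.],
}


def shard_owner(shard):
    """Return the rank whose expected shard equals `shard`."""
    for rank in range(WORLD_SIZE):
        if expected_shard(rank) == shard:
            return rank
    raise ValueError(f"{shard} is not a valid shard")


for rank, shard in observed.items():
    owner = shard_owner(shard)
    note = "ok" if owner == rank else f"expected on rank {owner}"
    print(f"rank {rank}: got {shard}, {note}")
```

Ranks 0 and 1 hold the right shard, while ranks 2 through 7 hold the shards expected on ranks 6, 7, 2, 3, 4, and 5 respectively.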

(Also, xm.rendezvous is not working under PJRT yet: it doesn't actually introduce a barrier across all ranks.)

ronghanghu avatar Aug 03 '22 16:08 ronghanghu

@ronghanghu Thanks for flagging the issue with the other collectives. I did check all_gather as well, but I didn't think to try with pin_layout=False. This snippet gives the expected results:

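import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def _mp_fn():
  device = xm.xla_device()

  # Each process contributes a length-3 tensor filled with its own ordinal.
  ones = torch.ones((3), device=device) * xm.get_ordinal()

  res = xm.all_gather(ones, pin_layout=True)
  xm.mark_step()

  print(xm.get_ordinal(), res)


pjrt.run_multiprocess(_mp_fn)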
$ PJRT_DEVICE=TPU python all_gather.py
0 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
2 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
3 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
1 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')

But with pin_layout=False, I get this:

$ PJRT_DEVICE=TPU python all_gather.py
3 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
1 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
0 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
2 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')

While working on the tests for this PR, I found that the TPU chips' device IDs don't actually match the device indices in TPU_VISIBLE_DEVICES (which do correspond to xm.get_ordinal). For example, on a v4-8, I found that the PjRt device IDs are ordered ['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']. I bet that is related. Although, I'm not entirely sure why pin_layout would affect the results like this.

Can you file an issue assigned to me to look into reduce_scatter, all_gather, and all_to_all?

will-cromar avatar Aug 03 '22 17:08 will-cromar

Also, xm.rendezvous doesn't work yet, but another early tester told us that they were able to work around it by creating a gloo process group and using dist.barrier.

will-cromar avatar Aug 03 '22 17:08 will-cromar

Thanks @will-cromar, I'll submit a new issue for all-gather, reduce-scatter, and all-to-all under PJRT.

I found that the PjRt device IDs are ordered ['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']. I bet that is related.

Yeah, I think this is the underlying cause of the all_gather and reduce_scatter difference. It would be great to make them consistent (otherwise, many existing programs that rely on xm.reduce_scatter and other collective ops will not work as expected).

Although, I'm not entirely sure why pin_layout would affect the results like this.

I think this is because xm.all_gather only performs an actual all-gather under pin_layout=False. Under pin_layout=True, xm.all_gather instead pads the inputs with zeros and then performs all_reduce, as in https://github.com/pytorch/xla/blob/ff2e0ea559b5f8785c7f577834af2b5af84b6849/torch_xla/core/xla_model.py#L673-L678. This behavior was introduced in https://github.com/pytorch/xla/pull/3568 and motivated by previous issues such as https://github.com/pytorch/xla/pull/3511.
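
To illustrate the idea, here is a rough plain-Python sketch of an all-gather built from zero-padding plus a summing all-reduce. It is a simplified model, not the actual torch_xla implementation; `all_gather_via_all_reduce` is a hypothetical helper, and it assumes each rank writes its values at an offset derived from its own ordinal.

```python
def all_gather_via_all_reduce(per_rank_inputs):
    """Emulate all-gather by zero-padding each input and summing across ranks."""
    world_size = len(per_rank_inputs)
    length = len(per_rank_inputs[0])

    padded = []
    for rank, values in enumerate(per_rank_inputs):
        # Zeros everywhere except this rank's own slot.
        buf = [0.0] * (world_size * length)
        buf[rank * length:(rank + 1) * length] = values
        padded.append(buf)

    # all_reduce(SUM): every rank ends up with the element-wise sum.
    return [sum(column) for column in zip(*padded)]


inputs = [[float(rank)] * 3 for rank in range(4)]
print(all_gather_via_all_reduce(inputs))
# [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]
```

Under this model each slot is chosen by the rank itself, so the gathered order would follow xm.get_ordinal even if the runtime numbers its devices differently, which would be consistent with the pin_layout=True output above.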

ronghanghu avatar Aug 03 '22 17:08 ronghanghu

Also, xm.rendezvous doesn't work yet, but we had another early tester tell us that they were able to work around it by creating a gloo process group and using dist.barrier

This is good to know! I guess this is probably a bit harder to do with threads on TPU v3, though.

ronghanghu avatar Aug 03 '22 17:08 ronghanghu

dist.barrier will almost certainly not work with threads if you use the global default process group (i.e. the one created by init_process_group), because every thread will share the same process group (PG). It might work if you call dist.new_group in each thread and pass that group into barrier directly.

There's an example in the docs of using nccl as the default process group and manually initializing a separate gloo process group to use with barrier: https://pytorch.org/docs/stable/distributed.html#monitored-barrier

will-cromar avatar Aug 03 '22 17:08 will-cromar

@will-cromar I created an issue in https://github.com/pytorch/xla/issues/3824 with a simple test example for all-gather, reduce-scatter, and all-to-all (but I cannot assign the issue to you since I don't have edit access to this repo 😅 )

ronghanghu avatar Aug 03 '22 21:08 ronghanghu

@ronghanghu I will give you write access 😄

JackCaoG avatar Aug 05 '22 17:08 JackCaoG

@ronghanghu I will give you write access 😄

Great, thank you!

ronghanghu avatar Aug 05 '22 17:08 ronghanghu

Thanks @JackCaoG. I'll work on a README showing how to port from XRT to PjRt and how to run models without xla_dist.

will-cromar avatar Aug 17 '22 21:08 will-cromar