TPU Pod support with PjRt
- Move TPU-specific logic from `pjrt.py` to `tpu.py`. All of this logic is broadly applicable to all TPU VMs and isn't strictly related to PjRt.
- Update `configure_topology` to support multiple hosts.
- Create unit tests for `tpu.py` (`test_experimental_tpu.py`). Everything that requires a TPU is mocked out, so this can run on CPU.
- Create a short integration test for TPU (`test_experimental_pjrt_tpu.py`). This one initializes the TPU runtime in each test, so it must run on a TPU. Added configs for v3-8 and v4-8. Tested manually on both.
I know the naming of these two new tests is confusing. Let me know if you have alternate suggestions.
PJRT doesn't make a distinction between processes on different hosts, and we already support multiple processes on one host, so no low-level changes were necessary. This PR mainly deals with automatically configuring the TPU topology variables based on the environment.
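For intuition only, here is a rough sketch of the kind of per-process setup this implies. It is not the actual `configure_topology` logic: only `TPU_VISIBLE_DEVICES` appears elsewhere in this thread, and the rank arithmetic below is an illustrative assumption.

```python
import os

# Hypothetical sketch, not the real tpu.py implementation: pin each local
# process to one chip and derive its global rank from the host index.
def configure_one_process(local_rank, local_world_size, host_index, num_hosts):
  # TPU_VISIBLE_DEVICES restricts this process to a single chip on its host.
  os.environ['TPU_VISIBLE_DEVICES'] = str(local_rank)
  # Illustrative assumption: global rank/world size derived from the host index.
  global_rank = host_index * local_world_size + local_rank
  global_world_size = num_hosts * local_world_size
  return global_rank, global_world_size
```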
Simple test case to sanity-check that collectives work as expected:
```
$ gcloud compute tpus tpu-vm ssh --project=tpu-pytorch --zone=us-central2-b wcromar-v4-32 --internal-ip --worker=all --command 'PJRT_DEVICE=TPU python3 -c "
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt as pjrt
import torch

def f():
  ix = torch.ones([5], device=xm.xla_device()) * xm.get_ordinal()
  return xm.all_reduce(xm.REDUCE_SUM, ix).cpu().numpy()

print(pjrt.run_multiprocess(f))
"'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
{3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
{2: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 0: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 3: {0: array([120., 120., 120., 120., 120.], dtype=float32)}, 1: {0: array([120., 120., 120., 120., 120.], dtype=float32)}}
```
On a v4-32, there are 16 total workers numbered `[0, 16)`, so the expected result is `sum(range(16)) = 120`.
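For reference, the expected value can be checked without a TPU; this is just the arithmetic behind the test above, simulated on CPU with plain `torch`:

```python
import torch

# Each of the 16 workers contributes ones([5]) * ordinal; all_reduce(SUM) adds them.
world_size = 16
expected = sum(torch.ones(5) * ordinal for ordinal in range(world_size))
print(expected)  # tensor([120., 120., 120., 120., 120.])
assert expected.tolist() == [120.0] * 5
```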
Also confirmed that our ResNet50 example works on v4-16 and v4-32 and gets about the same per-chip performance as on v4-8 with fake data.
Really looking forward to this PR!
> Simple test case to sanity-check that collectives work as expected:
It would be great to also resolve the `reduce_scatter`, `all_gather`, and `all_to_all` collective ops in PJRT 😃

Currently `all_reduce` works well on v3-8 after https://github.com/pytorch/xla/pull/3704, but `reduce_scatter` still doesn't work well, and `all_gather` under `pin_layout=False` doesn't work well (`all_gather` under `pin_layout=True` is actually using `all_reduce`).
For example, `reduce_scatter` seems to do the "reduce" part correctly, but the "scatter" part delivers a shard whose position doesn't match the rank from `xm.get_ordinal`, as shown in the case below when running on a v3-8 TPU VM:
```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def test_pjrt_collective_ops():
  rank = xm.get_ordinal()
  world_size = xm.xrt_world_size()
  device = xm.xla_device()

  t = torch.arange(16, dtype=torch.float32).view(8, 2)
  t = t.to(device)
  xm.mark_step()

  pin_layout = False
  reduce_scatter_out = xm.reduce_scatter(
      xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0,
      shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}: {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  pin_layout = True
  reduce_scatter_out = xm.reduce_scatter(
      xm.REDUCE_SUM, t, scale=1.0 / world_size, scatter_dim=0,
      shard_count=world_size, pin_layout=pin_layout)
  xm.mark_step()
  print(f"rank {rank} of {world_size}:\nreduce_scatter_out (pin_layout={pin_layout}: {reduce_scatter_out}\n", end="", flush=True)
  xm.rendezvous(f"reduce_scatter_out pin_layout={pin_layout}")

  return 0.


if __name__ == '__main__':
  pjrt.run_multiprocess(test_pjrt_collective_ops)
```
which prints
```
rank 0 of 8:
reduce_scatter_out (pin_layout=False: tensor([[0., 1.]], device='xla:0')
rank 1 of 8:
reduce_scatter_out (pin_layout=False: tensor([[2., 3.]], device='xla:1')
rank 2 of 8:
reduce_scatter_out (pin_layout=False: tensor([[12., 13.]], device='xla:0')
rank 3 of 8:
reduce_scatter_out (pin_layout=False: tensor([[14., 15.]], device='xla:1')
rank 4 of 8:
reduce_scatter_out (pin_layout=False: tensor([[4., 5.]], device='xla:0')
rank 5 of 8:
reduce_scatter_out (pin_layout=False: tensor([[6., 7.]], device='xla:1')
rank 6 of 8:
reduce_scatter_out (pin_layout=False: tensor([[8., 9.]], device='xla:0')
rank 7 of 8:
reduce_scatter_out (pin_layout=False: tensor([[10., 11.]], device='xla:1')
rank 0 of 8:
reduce_scatter_out (pin_layout=True: tensor([[0., 1.]], device='xla:0')
rank 1 of 8:
reduce_scatter_out (pin_layout=True: tensor([[2., 3.]], device='xla:1')
rank 2 of 8:
reduce_scatter_out (pin_layout=True: tensor([[12., 13.]], device='xla:0')
rank 3 of 8:
reduce_scatter_out (pin_layout=True: tensor([[14., 15.]], device='xla:1')
rank 4 of 8:
reduce_scatter_out (pin_layout=True: tensor([[4., 5.]], device='xla:0')
rank 5 of 8:
reduce_scatter_out (pin_layout=True: tensor([[6., 7.]], device='xla:1')
rank 6 of 8:
reduce_scatter_out (pin_layout=True: tensor([[8., 9.]], device='xla:0')
rank 7 of 8:
reduce_scatter_out (pin_layout=True: tensor([[10., 11.]], device='xla:1')
```
One can see that the scatter shard each rank receives doesn't match its rank. Similarly, `all_gather` under `pin_layout=False` has the same problem.

(Besides, `xm.rendezvous` is not working under PJRT yet -- it doesn't actually introduce a barrier across all ranks.)
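For comparison, here is what each rank should receive from the `reduce_scatter` call above, computed on CPU with plain `torch` (assuming all 8 ranks pass the same `t`, the reduce step gives back `t`, and rank `r` should get row `r`):

```python
import torch

world_size = 8
t = torch.arange(16, dtype=torch.float32).view(8, 2)

# Reduce: the sum of 8 identical copies, scaled by 1/world_size, is just t again.
reduced = (t * world_size) * (1.0 / world_size)

# Scatter along dim 0 into 8 shards: rank r should receive row r.
for rank in range(world_size):
  print(f"rank {rank} expected: {reduced[rank:rank + 1]}")
# e.g. rank 2 should see [[4., 5.]], but the log above shows it receiving [[12., 13.]].
```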
@ronghanghu Thanks for flagging the issue with the other collectives. I did check `all_gather` as well, but I didn't think to try with `pin_layout=False`. This snippet gives the expected results:
```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

def _mp_fn():
  device = xm.xla_device()
  ones = torch.ones((3), device=device) * xm.get_ordinal()
  res = xm.all_gather(ones, pin_layout=True)
  xm.mark_step()
  print(xm.get_ordinal(), res)

pjrt.run_multiprocess(_mp_fn)
```
```
$ PJRT_DEVICE=TPU python all_gather.py
0 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
2 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
3 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
1 tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='xla:0')
```
But with `pin_layout=False`, I get this:
```
$ PJRT_DEVICE=TPU python all_gather.py
3 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
1 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
0 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
2 tensor([0., 0., 0., 3., 3., 3., 2., 2., 2., 1., 1., 1.], device='xla:0')
```
While working on the tests for this PR, I found that the TPU chips' device IDs don't actually match the device indices in `TPU_VISIBLE_DEVICES` (which do correspond to `xm.get_ordinal`). For example, on a v4-8, I found that the PjRt device IDs are ordered `['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']`. I bet that is related. Although, I'm not entirely sure why `pin_layout` would affect the results like this.
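If it helps with reproducing this, here is a small sketch for inspecting the ordinal-to-device mapping per process; it assumes `xm.xla_real_devices` is available in your build to map the logical `xla:N` device to the underlying `TPU:N` ID:

```python
import os

import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt


def _show_device_mapping():
  device = xm.xla_device()
  # Assumed helper: maps logical devices like 'xla:0' to runtime IDs like 'TPU:2'.
  real = xm.xla_real_devices([str(device)])
  print(f"ordinal={xm.get_ordinal()} "
        f"TPU_VISIBLE_DEVICES={os.environ.get('TPU_VISIBLE_DEVICES')} "
        f"device={device} real={real}")


if __name__ == '__main__':
  pjrt.run_multiprocess(_show_device_mapping)
```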
Can you file an issue assigned to me to look into `reduce_scatter`, `all_gather`, and `all_to_all`?
Also, `xm.rendezvous` doesn't work yet, but we had another early tester tell us that they were able to work around it by creating a `gloo` process group and using `dist.barrier`.
Thanks @will-cromar, I'll submit a new issue for all-gather, reduce-scatter, and all-to-all under PJRT.
> I found that the PjRt device IDs are ordered `['TPU:0', 'TPU:2', 'TPU:3', 'TPU:1']`. I bet that is related.
Yeah, I think this is the underlying cause of the `all_gather` and `reduce_scatter` difference. It would be great to make them consistent (otherwise, many existing programs that rely on `xm.reduce_scatter` and other collective ops cannot work as expected).
> Although, I'm not entirely sure why pin_layout would affect the results like this.
I think this is because `xm.all_gather` only performs an actual all-gather under `pin_layout=False`. Under `pin_layout=True`, `xm.all_gather` doesn't actually perform an all-gather; instead, it first pads the inputs with zeros and then performs `all_reduce`, as in https://github.com/pytorch/xla/blob/ff2e0ea559b5f8785c7f577834af2b5af84b6849/torch_xla/core/xla_model.py#L673-L678. This was introduced in https://github.com/pytorch/xla/pull/3568 and motivated by previous issues such as https://github.com/pytorch/xla/pull/3511.
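To illustrate why the pinned-layout path still gives the right answer, here is a CPU-only simulation of that pad-then-`all_reduce` trick (not the library code, just the idea): each rank writes its shard into a zero buffer at its own offset, and summing the buffers across ranks reconstructs the gathered result.

```python
import torch

world_size = 4
shard_len = 3
# Each rank's local contribution: ones(3) * ordinal, as in the snippet above.
shards = [torch.ones(shard_len) * r for r in range(world_size)]

# Emulated "pinned" all_gather: pad each shard into a full-size buffer at its
# rank's offset, then all_reduce(SUM) the buffers.
padded = []
for rank, shard in enumerate(shards):
  buf = torch.zeros(world_size * shard_len)
  buf[rank * shard_len:(rank + 1) * shard_len] = shard
  padded.append(buf)

gathered = torch.stack(padded).sum(dim=0)
print(gathered)  # tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.])
```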
> Also, `xm.rendezvous` doesn't work yet, but we had another early tester tell us that they were able to work around it by creating a `gloo` process group and using `dist.barrier`
This is good to know! I guess this is probably a bit harder to do with threads on TPU v3, though.
`barrier` will almost certainly not work with threads if you use the global default process group (i.e. use `init_process_group`), because each thread will use the same PG. It might work if you call `dist.new_group` in each thread and pass that group into `barrier` directly.
There's an example in the docs of using `nccl` as the default process group and manually initializing a separate `gloo` process group to use with `barrier`: https://pytorch.org/docs/stable/distributed.html#monitored-barrier
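A hedged sketch of that workaround for the process-per-device case (not threads); the `env://` initialization and the address/port values are placeholder assumptions, and on a multi-host pod `MASTER_ADDR` would need to point at one worker:

```python
import os

import torch.distributed as dist
import torch_xla.core.xla_model as xm


def cpu_barrier():
  # Assumed setup: one process per device, with ranks taken from xm.* so they
  # line up with the XLA ordinals.
  if not dist.is_initialized():
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group(
        backend='gloo',
        rank=xm.get_ordinal(),
        world_size=xm.xrt_world_size())
  # Barrier over the gloo group, independent of the XLA graph.
  dist.barrier()
```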
@will-cromar I created an issue in https://github.com/pytorch/xla/issues/3824 with a simple test example for all-gather, reduce-scatter, and all-to-all (but I cannot assign the issue to you since I don't have edit access to this repo 😅 )
@ronghanghu I will give you write access 😄
Great, thank you!
Thanks @JackCaoG. I'll work on a README showing how to port from XRT to PjRt and how to run models without `xla_dist`.