xla
Adding support for torchrun in xla backend
This pull request adds the code needed to integrate the torchrun launcher with the XLA backend.
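For context, a rough sketch of what a torchrun-based launch could look like with this change; the script name, process counts, and exact init arguments here are illustrative assumptions, not taken from this PR:

# hypothetical launch: e.g.  torchrun --nnodes=1 --nproc_per_node=2 train.py
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' process group backend

dist.init_process_group('xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
device = xm.xla_device()  # each torchrun-spawned process drives one XLA device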
@amithrm Can you provide a test for this new feature?
@JackCaoG sure..will add tests
@JackCaoG I changed the initialization a bit to take into account how SLURM configures the devices. Please take a look at it and also at the test cases. All of these would need more modifications after we discuss.
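To make that concrete, a minimal sketch of the environment torchrun (or a SLURM-driven launcher) hands to each process; the configure_xla_worker helper below is hypothetical and not the actual code in this PR:

import os

# torchrun exports these for every spawned process; under SLURM a launcher
# typically derives them from SLURM_PROCID / SLURM_NTASKS / SLURM_LOCALID
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
local_rank = int(os.environ['LOCAL_RANK'])
master_addr = os.environ['MASTER_ADDR']
master_port = int(os.environ['MASTER_PORT'])

def configure_xla_worker(rank, world_size, local_rank, master_addr, master_port):
  # hypothetical: translate the torchrun view of the job into the per-worker
  # XRT topology (worker name, device map, rendezvous endpoint) the backend needs
  pass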
I have to admit that I am not an expert on torchrun, let me read up on the documentation first lol. Looping in @will-cromar to make sure this does not conflict with our future PJRT runtime.
We did some internal testing. It appears that at scale, we see issues with the setup of gRPC channels. We should understand whether you see similar issues on your end too.
@JackCaoG A simple test that you can run on GPU-XLA:
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import os

def _mp_fn(index):
  # show how the XRT topology was configured for this process
  print('XRT_LOCAL_WORKER:{}'.format(os.environ['XRT_LOCAL_WORKER']))
  print('XRT_DEVICE_MAP:{}'.format(os.environ['XRT_DEVICE_MAP']))
  print('XRT_WORKERS:{}'.format(os.environ['XRT_WORKERS']))
  print('XRT_HOST_WORLD_SIZE:{}'.format(os.environ['XRT_HOST_WORLD_SIZE']))

  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # each process contributes its own ordinal; the all_reduce sums them
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  print('rank:{}, value:{}'.format(index, ordinal_tensor))
  result = xm.all_reduce('sum', ordinal_tensor)
  cpu_result = result.cpu()
  print('rank:{}, value:{}'.format(index, cpu_result))

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), nprocs=2, join=True)
Run command:
GPU_NUM_DEVICES=2 python3 allreduce_xla.py
This will output:
XRT_LOCAL_WORKER:localservice:0
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097
XRT_LOCAL_WORKER:localservice:1
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097
If you look at XRT_WORKERS, it holds the gRPC endpoint string for every worker. This won't scale with the number of workers.
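To illustrate the scaling concern, a small sketch (the host name and ports are made up): XRT_WORKERS is a flat list with one gRPC endpoint per worker, so every process carries, and may open channels to, the full worker mesh.

# build an XRT_WORKERS-style string for n workers; its size grows linearly with n,
# and so does the number of gRPC channels each worker may need to establish
def xrt_workers_string(n, host='dfda805bbe4b', base_port=49000):
  return '|'.join(
      'localservice:{};grpc://{}:{}'.format(i, host, base_port + i)
      for i in range(n))

print(xrt_workers_string(2))         # same shape as the 2-worker output above
print(len(xrt_workers_string(256)))  # grows with the worker count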
We want to support torchrun with the new PJRT runtime; in the meantime, if this torchrun utility can unblock the AWS folks, we can also take it.
I am a bit hesitant to claim official support for XRT:TPU + torchrun. @will-cromar let's investigate what the gap is here; if it is free we might as well just take it.
In terms of moving the code to experimental, I guess it comes down to how we want to set user expectations. Are there any known caveats for this feature? If it works well we don't need to put it in experimental. However, we should add a README similar to https://github.com/pytorch/xla/blob/master/docs/ddp.md
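For reference, the kind of snippet such a README would likely show, patterned on docs/ddp.md; the exact init arguments here are an assumption, not a confirmed API contract for this PR:

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group('xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())
device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)
ddp_model = DDP(model)  # wraps the model for gradient synchronization across XLA devices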
@will-cromar can you take another pass of this pr when you have some time?
Looks like the file needed for the test (allreduce_torchrun.py) is not getting picked up. Checking with @will-cromar on how to fix this. Some yapf fixes are also pending in one file.
@will-cromar I see a build failure: NameError: name 'sympy' is not defined
weird, head is green right now https://github.com/pytorch/xla/commits/master
Ah ok https://github.com/pytorch/xla/pull/4313/files should fix it, can you rebase again?
@JackCaoG All 4 tests pass. @will-cromar is there anything else needed?
Thanks @JackCaoG and @amithrm !