
Adding support for torchrun in xla backend

Open amithrm opened this issue 2 years ago • 7 comments

This pull request enables the code needed to integrate torchrun launcher with xla backend.
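For context, torchrun exports a standard set of rendezvous environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`) to every worker it launches. A minimal sketch of how a backend integration might read them; the fallback defaults are illustrative only, so the snippet also runs outside a torchrun launch:

```python
import os


def read_torchrun_env():
  """Read the rendezvous info that torchrun exports for each worker.

  The defaults below are illustrative fallbacks for single-process runs;
  a real launcher integration would require the variables to be set.
  """
  return {
      'rank': int(os.environ.get('RANK', 0)),
      'local_rank': int(os.environ.get('LOCAL_RANK', 0)),
      'world_size': int(os.environ.get('WORLD_SIZE', 1)),
      'master_addr': os.environ.get('MASTER_ADDR', '127.0.0.1'),
      'master_port': int(os.environ.get('MASTER_PORT', 29500)),
  }


if __name__ == '__main__':
  print(read_torchrun_env())
```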

amithrm avatar May 25 '22 05:05 amithrm

@amithrm Can you provide a test for this new feature?

JackCaoG avatar May 25 '22 18:05 JackCaoG

@JackCaoG sure..will add tests

amithrm avatar May 25 '22 19:05 amithrm

@JackCaoG I changed the initialization a bit to take into account how Slurm configures the devices. Please take a look at it and also the test cases. All of these would need more modifications after we discuss.
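For readers unfamiliar with Slurm: it describes the process layout through environment variables such as `SLURM_PROCID`, `SLURM_LOCALID`, and `SLURM_NTASKS`. A hypothetical sketch of deriving rank and world size from them (the defaults are illustrative fallbacks so the snippet runs outside a Slurm allocation, not actual Slurm behavior):

```python
import os


def slurm_rank_info():
  """Derive global rank, local rank, and world size from Slurm's
  environment variables. Defaults are illustrative fallbacks only."""
  rank = int(os.environ.get('SLURM_PROCID', 0))        # global task index
  local_rank = int(os.environ.get('SLURM_LOCALID', 0))  # index on this node
  world_size = int(os.environ.get('SLURM_NTASKS', 1))   # total task count
  return rank, local_rank, world_size


if __name__ == '__main__':
  print(slurm_rank_info())
```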

amithrm avatar Jun 08 '22 19:06 amithrm

I have to admit that I am not an expert on torchrun, let me read up on the documentation first lol. Looping in @will-cromar to make sure this does not conflict with our future PJRT runtime.

JackCaoG avatar Jun 09 '22 23:06 JackCaoG

We did some internal testing. It appears that at scale we see issues with the setup of gRPC channels. We should understand whether you see similar issues on your end too.

amithrm avatar Jun 22 '22 22:06 amithrm

@JackCaoG A simple test that you can run on GPU-XLA:

import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # Print the XRT topology variables configured for this worker.
  print('XRT_LOCAL_WORKER:{}'.format(os.environ['XRT_LOCAL_WORKER']))
  print('XRT_DEVICE_MAP:{}'.format(os.environ['XRT_DEVICE_MAP']))
  print('XRT_WORKERS:{}'.format(os.environ['XRT_WORKERS']))
  print('XRT_HOST_WORLD_SIZE:{}'.format(os.environ['XRT_HOST_WORLD_SIZE']))
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # Each worker contributes its ordinal; the all_reduce below sums them.
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  print('rank:{}, value:{}'.format(index, ordinal_tensor))
  result = xm.all_reduce('sum', ordinal_tensor)

  cpu_result = result.cpu()
  print('rank:{}, value:{}'.format(index, cpu_result))


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), nprocs=2, join=True)

amithrm avatar Jun 23 '22 20:06 amithrm

Run command:

GPU_NUM_DEVICES=2 python3 allreduce_xla.py

This will output:

XRT_LOCAL_WORKER:localservice:0
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097
XRT_LOCAL_WORKER:localservice:1
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097

If you look at XRT_WORKERS, each worker's environment carries the gRPC endpoint string for every worker. This won't scale with the number of workers.
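To illustrate why (hypothetical helpers, not code from this PR): every worker's XRT_WORKERS value carries the full endpoint roster, so per-worker configuration grows linearly with worker count, and quadratically cluster-wide:

```python
def build_xrt_workers(hosts):
  """Build an XRT_WORKERS-style string: every worker's gRPC endpoint,
  joined with '|'. Each of the N workers receives this full list, so
  the per-worker value is O(N) bytes (O(N^2) across the cluster)."""
  return '|'.join(
      'localservice:{};grpc://{}:{}'.format(i, host, port)
      for i, (host, port) in enumerate(hosts))


def parse_xrt_workers(value):
  """Parse an XRT_WORKERS-style string back into {task_index: address}."""
  out = {}
  for entry in value.split('|'):
    name, addr = entry.split(';')
    out[int(name.split(':')[1])] = addr
  return out


if __name__ == '__main__':
  workers = build_xrt_workers([('dfda805bbe4b', 49887),
                               ('dfda805bbe4b', 33097)])
  print(workers)
  print(parse_xrt_workers(workers))
```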

amithrm avatar Jun 23 '22 20:06 amithrm

We want to support torchrun with the new PJRT runtime. In the meantime, if this torchrun utility can unblock the AWS folks, we can also take it.

I am a bit hesitant to claim official support for XRT:TPU + torchrun. @will-cromar let's investigate what the gap is here; if it is free, we might as well just take it.
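For reference, a torchrun launch under the PJRT runtime might look like the command fragment below; `train.py` is a placeholder script name, and the PJRT_DEVICE value depends on the platform:

```shell
# Hypothetical example: launch 2 local workers with torchrun.
# PJRT_DEVICE selects the PJRT backend (e.g. GPU or TPU).
PJRT_DEVICE=GPU torchrun --nnodes=1 --nproc_per_node=2 train.py
```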

JackCaoG avatar Nov 02 '22 18:11 JackCaoG

In terms of moving the code to experimental, I guess it comes down to how we want to set user expectations. Are there any known caveats for this feature? If it works well, we don't need to put it in experimental. However, we should add a README similar to https://github.com/pytorch/xla/blob/master/docs/ddp.md

JackCaoG avatar Nov 02 '22 22:11 JackCaoG

@will-cromar can you take another pass over this PR when you have some time?

JackCaoG avatar Dec 05 '22 18:12 JackCaoG

Looks like the test file (allreduce_torchrun.py) is not getting picked up. Checking with @will-cromar on how to fix this. Some yapf fixes are also pending in one file.

amithrm avatar Dec 08 '22 20:12 amithrm

@will-cromar I see a build failure: NameError: name 'sympy' is not defined

amithrm avatar Dec 13 '22 00:12 amithrm

Weird, head is green right now: https://github.com/pytorch/xla/commits/master

JackCaoG avatar Dec 13 '22 00:12 JackCaoG

Ah, OK. https://github.com/pytorch/xla/pull/4313/files should fix it; can you rebase again?

JackCaoG avatar Dec 13 '22 00:12 JackCaoG

@JackCaoG All 4 pass. @will-cromar is there anything else needed?

amithrm avatar Dec 13 '22 18:12 amithrm

Thanks @JackCaoG and @amithrm !

jeffhataws avatar Dec 13 '22 23:12 jeffhataws