
Cannot run 7-node PT distributed test

Open aws-rhsoln opened this issue 2 years ago • 11 comments

🐛 Bug

The test works with any number of nodes from 1 to 6. Adding a 7th node causes it to fail. Note that this is not caused by a bad node: a 6-node test works on any 6 of the 7 instances.

To Reproduce

We started with a simple all-reduce test:

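import torch
import torch_xla.core.xla_model as xm


def _mp_fn():
  world_size = xm.xrt_world_size()
  device = xm.xla_device()
  rank = xm.get_ordinal()
  ones = torch.ones((20, 30))
  xones = ones.to(device)
  if world_size > 0:
    print("running all reduce")
    for i in range(0, 2):
      print(f'at iteration {i}, with local rank {rank}', flush=True)
      # Sum the tensor across all replicas.
      result = xm.all_reduce(xm.REDUCE_SUM, xones)
      # Moving the result to CPU forces the pending XLA computation to run.
      result_cpu = result.cpu()
      print(result_cpu, flush=True)


if __name__ == '__main__':
  _mp_fn()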

We then simplified it further by removing the collective communication op: the test is still launched via PT distributed, but the workers no longer communicate with one another during execution.

import torch
import torch_xla.core.xla_model as xm


def _mp_fn():
  world_size = xm.xrt_world_size()
  device = xm.xla_device()
  rank = xm.get_ordinal()
  ones = torch.ones((20, 30))
  twos = 2 * torch.ones((20, 30))
  xones = ones.to(device)
  xtwos = twos.to(device)
  if world_size > 0:
    for i in range(0, 2):
      print(f'at iteration {i}, with local rank {rank}', flush=True)
      # Purely local element-wise op: no cross-worker communication.
      result = xones * xtwos
      result_cpu = result.cpu()
      print(result_cpu, flush=True)


if __name__ == '__main__':
  _mp_fn()

The command to run the test:

python3 -m torch.distributed.launch --nproc_per_node=32 --nnodes=7 --node_rank=1 --master_addr=10.0.11.15 --master_port=33666 ~/all_reduce.py --enable_dist_launch
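For reference, torch.distributed.launch gives each spawned process a global rank of node_rank * nproc_per_node + local_rank and a world size of nnodes * nproc_per_node. The sketch below only does that arithmetic for the command above. The helper name launch_ranks is hypothetical. With 32 processes on each of 7 nodes, the job has 224 processes, and node 1 hosts ranks 32 through 63.

```python
def launch_ranks(nproc_per_node: int, nnodes: int, node_rank: int):
    """Reproduce the rank arithmetic of torch.distributed.launch (hypothetical helper).

    Returns the world size and the global ranks of the processes started on
    the node identified by node_rank.
    """
    world_size = nproc_per_node * nnodes
    ranks = [node_rank * nproc_per_node + local_rank
             for local_rank in range(nproc_per_node)]
    return world_size, ranks


if __name__ == "__main__":
    # Values taken from the launch command above.
    world_size, ranks = launch_ranks(nproc_per_node=32, nnodes=7, node_rank=1)
    print(f"world size: {world_size}")
    print(f"ranks on node 1: {ranks[0]}..{ranks[-1]}")
```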

In both cases, adding the 7th node results in what looks like a gRPC failure at bootstrap. Different nodes report slightly different errors:

022-06-08 20:09:09.215047: W tensorflow/core/distributed_runtime/master_session.cc:2122] Unavailable: failed to connect to all addresses
Additional GRPC error information from remote target /job:localservice/replica:0/task:49:
:{"created":"@1654718949.211034423","description":"Failed to pick subchannel","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/client_channel.cc","file_line":3941,"referenced_errors":[{"created":"@1654718949.205917750","description":"failed to connect to all addresses","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":393,"grpc_status":14}]}
Traceback (most recent call last):
  File "/home/ec2-user/all_reduce.py", line 38, in <module>
    _mp_fn()
  File "/home/ec2-user/all_reduce.py", line 23, in _mp_fn
    xones = ones.to(device)
RuntimeError: tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:436 : Check failed: session->session()->Run(session_work->feed_inputs, session_work->outputs_handles, &outputs) == ::tensorflow::Status::OK() (Unavailable: failed to connect to all addresses
Additional GRPC error information from remote target /job:localservice/replica:0/task:49:
:{"created":"@1654718949.205919712","description":"Failed to pick subchannel","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/client_channel.cc","file_line":3941,"referenced_errors":[{"created":"@1654718949.205917750","description":"failed to connect to all addresses","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":393,"grpc_status":14}]} vs. OK)
*** Begin stack trace ***
    tensorflow::CurrentStackTrace()
    
    xla::util::MultiWait::Complete(std::function<void ()> const&)
    
    
    
    clone
*** End stack trace ***

Expected behavior

Environment

  • Reproducible on XLA backend [CPU/TPU]:
  • torch_xla version: 1.10

Additional context

aws-rhsoln, Jun 08 '22 22:06

I am not sure whether torch.distributed.launch is doing the right thing here, given my limited knowledge of this API.

It seems like you are trying to start 32 processes per host (meaning 32 devices) on 7 hosts. We have only tested 8 processes per host, but we can scale up to 250+ hosts without this error. I am not sure how you set up the config for your device. In my mind, what should happen is that you have one TF/XRT gRPC server per host and 32 XRT clients, one per process.
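A minimal sketch of the layout described here: one XRT gRPC server per host, with every process on that host talking only to its own host's server. The host names (host-0 ... host-6), the port and the function name are placeholders, not real configuration values. Under this layout each client needs to know a single address, no matter how many hosts there are.

```python
from collections import Counter


def expected_topology(num_hosts: int, procs_per_host: int, port: int = 51011):
    """Map each (host, local process) pair to the single server it talks to.

    Host names and the port are placeholders; only the shape of the mapping matters.
    """
    topology = {}
    for host in range(num_hosts):
        server = f"host-{host}:{port}"  # one XRT/TF gRPC server per host
        for local in range(procs_per_host):
            topology[(host, local)] = server  # each client uses its local server only
    return topology


if __name__ == "__main__":
    topo = expected_topology(num_hosts=7, procs_per_host=32)
    clients_per_server = Counter(topo.values())
    print(f"clients: {len(topo)}, servers: {len(clients_per_server)}")
    print(f"clients per server: {set(clients_per_server.values())}")
```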

The gRPC error is a generic gRPC server error (not TF-specific). A quick search turned up https://stackoverflow.com/questions/65854022/grpc-failed-to-pick-subchannel-if-server-and-client-are-hosted-on-different-ma. My guess is that some resource is being exhausted, but without more detail about your device config I don't have many ideas.

JackCaoG, Jun 10 '22 23:06

We found the issue. During bootstrap, the rate of DNS resolution requests is so high that some of them fail. Those failures cause some of the gRPC connections to fail, which eventually brings the whole cluster down.

Why the high rate of DNS requests? We pass the full world list of workers to every worker, and each entry is "hostname:port". There are 32 workers on every instance, so on an 8-node cluster each instance issues 32*8 (world size) * 32 (workers per instance) = 8192 DNS requests.
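The arithmetic above as a small sketch. It assumes, as described, that every worker resolves each "hostname:port" entry of the world list once, with no caching. The function name is illustrative. For 8 nodes it reproduces the 8192 figure, and it shows how the per-instance count grows with each node added.

```python
def dns_queries_per_instance(num_nodes: int, workers_per_node: int) -> int:
    """DNS lookups issued on one instance during bootstrap.

    Assumes every worker receives the full world list and resolves each
    "hostname:port" entry exactly once, with no caching.
    """
    world_size = num_nodes * workers_per_node  # entries in the world list
    return world_size * workers_per_node       # every local worker resolves them all


if __name__ == "__main__":
    for nodes in (6, 7, 8):
        per_instance = dns_queries_per_instance(nodes, workers_per_node=32)
        print(f"{nodes} nodes: {per_instance} lookups per instance")
```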

awsilya, Jun 22 '22 18:06

Can you point me to where the world list of workers is passed to every worker? Is this part of the XRT server (TF gRPC server) logic or the XRTClient logic? I was expecting each worker to talk only to its own XRT server and let the server handle the reduction. I might be missing something here.

JackCaoG, Jun 22 '22 23:06

A simple script to try on GPUs:

import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Print the XRT configuration visible to this process.
  print('XRT_LOCAL_WORKER:{}'.format(os.environ['XRT_LOCAL_WORKER']))
  print('XRT_DEVICE_MAP:{}'.format(os.environ['XRT_DEVICE_MAP']))
  print('XRT_WORKERS:{}'.format(os.environ['XRT_WORKERS']))
  print('XRT_HOST_WORLD_SIZE:{}'.format(os.environ['XRT_HOST_WORLD_SIZE']))
  device = xm.xla_device()
  # Each process contributes its own index; the all-reduce sums them.
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  print('rank:{}, value:{}'.format(index, ordinal_tensor))
  result = xm.all_reduce(xm.REDUCE_SUM, ordinal_tensor)

  cpu_result = result.cpu()
  print('rank:{}, value:{}'.format(index, cpu_result))


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), nprocs=2, join=True)

amithrm, Jun 23 '22 20:06

Run command:

GPU_NUM_DEVICES=2 python3 allreduce_xla.py

This will output:

XRT_LOCAL_WORKER:localservice:0
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097
XRT_LOCAL_WORKER:localservice:1
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097

If you look at XRT_WORKERS, it contains the gRPC address of every worker. This won't scale with the number of workers.
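To make the scaling point concrete, here is a sketch that splits an XRT_WORKERS value into its entries. The format ("name:task;grpc://host:port" entries joined by "|") is inferred from the output above, not from a specification, and parse_xrt_workers is a hypothetical helper. Every process gets the full list, so the number of addresses each one sees grows linearly with the world size.

```python
from urllib.parse import urlparse


def parse_xrt_workers(value: str):
    """Split an XRT_WORKERS string into (worker, host, port) tuples.

    Format inferred from the printed output: entries separated by "|",
    each entry being "<job>:<task>;grpc://<host>:<port>".
    """
    workers = []
    for entry in value.split("|"):
        name, address = entry.split(";", 1)
        url = urlparse(address)  # works for the custom "grpc://" scheme too
        workers.append((name, url.hostname, url.port))
    return workers


if __name__ == "__main__":
    sample = ("localservice:0;grpc://dfda805bbe4b:49887|"
              "localservice:1;grpc://dfda805bbe4b:33097")
    workers = parse_xrt_workers(sample)
    for worker in workers:
        print(worker)
    # Each process receives the whole list, so it sees len(workers) addresses.
    print(f"entries per process: {len(workers)}")
```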

amithrm, Jun 23 '22 20:06

I think this might be caused by outdated code in XRT. There are 3 ways to configure it (actually 4, but one of them only works for TPU):

  • https://github.com/pytorch/xla/blob/9525a8ec69af8ecc4df066871e8932555b4759ce/third_party/xla_client/computation_client.cc#L234
  • https://github.com/pytorch/xla/blob/9525a8ec69af8ecc4df066871e8932555b4759ce/third_party/xla_client/computation_client.cc#L248
  • https://github.com/pytorch/xla/blob/9525a8ec69af8ecc4df066871e8932555b4759ce/third_party/xla_client/computation_client.cc#L169

They are used here: https://github.com/pytorch/xla/blob/9525a8ec69af8ecc4df066871e8932555b4759ce/third_party/xla_client/computation_client.cc#L277-L281

We always use ParseEnvBasedTpuClusterConfig for TPU. For GPU, setting GPU_NUM_DEVICES=x makes it use ParseEnvDeviceCounts. The way you set up XRT_WORKERS and XRT_DEVICE_MAP makes it use ParseEnvDevices. I think the gRPC storm is coming from https://github.com/pytorch/xla/blob/9525a8ec69af8ecc4df066871e8932555b4759ce/third_party/xla_client/computation_client.cc#L257-L262

Note that with the other two configs, only one worker is added to workers_map in https://github.com/pytorch/xla/blob/9525a8ec69af8ecc4df066871e8932555b4759ce/third_party/xla_client/computation_client.cc#L155-L157

This code predates my time on the project. We only use XRT_DEVICE_MAP for single-CPU usage. My suggestion would be to configure XRT differently. I suspect that connecting every worker to every other worker is not needed; we definitely do not do that for TPU clusters.
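A rough count of why adding only one worker per process matters. This sketch assumes each workers_map entry costs one address resolution or connection setup per process and ignores caching and connection reuse. The helper is hypothetical, not torch_xla code. With 32 processes per host, the all-to-all layout grows with the square of the world size, while the single-entry layout grows linearly.

```python
def worker_map_entries(num_hosts: int, procs_per_host: int, all_to_all: bool) -> int:
    """Total workers_map entries across the whole job.

    all_to_all=True: every process lists every worker (the XRT_WORKERS /
    XRT_DEVICE_MAP setup discussed above).
    all_to_all=False: every process lists a single worker, as with the
    other two config paths.
    """
    world = num_hosts * procs_per_host
    per_process = world if all_to_all else 1
    return world * per_process


if __name__ == "__main__":
    for hosts in (6, 7, 8):
        full = worker_map_entries(hosts, 32, all_to_all=True)
        single = worker_map_entries(hosts, 32, all_to_all=False)
        print(f"{hosts} hosts: all-to-all={full}, one worker each={single}")
```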

JackCaoG, Jun 23 '22 22:06

Maybe you can verify whether there is actually a gRPC storm when using GPU_NUM_DEVICES. I would suggest not configuring through XRT_DEVICE_MAP and XRT_WORKERS. Maybe you can hack it a bit and use the GPU_NUM_DEVICES code path. I don't know how you actually configure your ASIC.

JackCaoG, Jun 23 '22 22:06

Hi Jack, thanks for the pointers! I went over the code flow. The xmp.spawn() code pasted above takes the same path as GPU_NUM_DEVICES. My understanding (I will confirm after talking to my teammates) is that we added an extra branch along the same lines as the GPU one, which creates the workers in the same way.

If there is a way to port the TPU config code path, that will most likely solve our problem. Is there any way we can make that code TPU-agnostic? That would be the best way to resolve this. Please let us know what you think.

amithrm, Jun 27 '22 16:06

Yeah, I think replicating the GPU path might be a better idea. If you look at our GPU code path in XrtComputationClient, you will find some GPU-specific logic, like https://github.com/pytorch/xla/blob/b6a7c32cdf252428f457ba3c95c148d8526a33dc/third_party/xla_client/xrt_computation_client.cc#L1109-L1115 and https://github.com/pytorch/xla/blob/b6a7c32cdf252428f457ba3c95c148d8526a33dc/third_party/xla_client/xrt_computation_client.cc#L1459-L1467

The TPU code path is a bit special. In the parsing logic at https://github.com/pytorch/xla/blob/b6a7c32cdf252428f457ba3c95c148d8526a33dc/third_party/xla_client/computation_client.cc#L169-L186 we hardcode num_tpu_device to 8, which may or may not be a valid assumption for your device.

In XrtComputationClient, TPU needs to issue a special op to get the TPU mesh: https://github.com/pytorch/xla/blob/b6a7c32cdf252428f457ba3c95c148d8526a33dc/third_party/xla_client/xrt_computation_client.cc#L1337-L1356

It also does other things, like having one process fetch the mesh and send it to all other clients. I feel this is too TPU-specific and hard to generalize. We are also moving to PJRT, so I would prefer not to change XRT too much right now.

JackCaoG, Jun 28 '22 05:06

So I dumped the env vars, @amithrm:

master host and master process

2022-07-08 00:06:27 10.164.0.121 [0] TPU_MESH_CONTROLLER_PORT: 8476
2022-07-08 00:06:27 10.164.0.121 [0] XRT_LOCAL_WORKER: c_localservice:0
2022-07-08 00:06:27 10.164.0.121 [0] XRT_SHARD_ORDINAL: 0
2022-07-08 00:06:27 10.164.0.121 [0] TPU_HOST_BOUNDS: 2,2,1
2022-07-08 00:06:27 10.164.0.121 [0] TPU_NUM_DEVICES: 8
2022-07-08 00:06:27 10.164.0.121 [0] XRT_TPU_CONFIG: c_localservice;0;10.164.0.121:51011|c_localservice;1;10.164.0.122:51011|c_localservice;2;10.164.0.124:51011|c_localservice;3;10.164.0.123:51011
2022-07-08 00:06:27 10.164.0.121 [0] LD_LIBRARY_PATH: :/usr/local/lib
2022-07-08 00:06:27 10.164.0.121 [0] XRT_SHARD_WORLD_SIZE: 32
2022-07-08 00:06:27 10.164.0.121 [0] TPU_MESH_CONTROLLER_ADDRESS: 10.164.0.121:8476
2022-07-08 00:06:27 10.164.0.121 [0] XRT_MESH_SERVICE_ADDRESS: 10.164.0.121:8477
2022-07-08 00:06:27 10.164.0.121 [0] TPUVM_MODE: 1
2022-07-08 00:06:27 10.164.0.121 [0] TF_CPP_MIN_LOG_LEVEL: 1
2022-07-08 00:06:27 10.164.0.121 [0] GRPC_VERBOSITY: ERROR
2022-07-08 00:06:27 10.164.0.121 [0] ALLOW_MULTIPLE_LIBTPU_LOAD: 1
2022-07-08 00:06:27 10.164.0.121 [0] XRT_START_LOCAL_SERVER: 0
2022-07-08 00:06:27 10.164.0.121 [0] TF_GRPC_DEFAULT_OPTIONS: grpc.keepalive_time_ms=60000,grpc.keepalive_timeout_ms=14400000,grpc.http2.max_pings_without_data=0,grpc.http2.min_ping_interval_without_data_ms=300000
2022-07-08 00:06:27 10.164.0.121 [0] XLA_FLAGS:  --xla_cpu_enable_fast_math=false
2022-07-08 00:06:27 10.164.0.121 [0] TPU_LIBRARY_PATH: /usr/local/lib/python3.8/dist-packages/libtpu/libtpu.so
2022-07-08 00:06:27 10.164.0.121 [0] XRT_TORCH_DIST_ROOT: t1v-n-4eb727ae-w-0.europe-west4-a.c.tpu-prod-env-one-vm.internal:34415
2022-07-08 00:06:27 10.164.0.121 [0] XRT_HOST_WORLD_SIZE: 4
2022-07-08 00:06:27 10.164.0.121 [0] XRT_SHARD_LOCAL_ORDINAL: 0
2022-07-08 00:06:27 10.164.0.121 [0] XRT_MULTI_PROCESSING_DEVICE: TPU:0

master host non master process

2022-07-08 00:13:34 10.164.0.121 [0] XLA_EMIT_STEPLOG: 1
2022-07-08 00:13:34 10.164.0.121 [0] CLOUD_TPU_TASK_ID: 0
2022-07-08 00:13:34 10.164.0.121 [0] TPU_CHIPS_PER_HOST_BOUNDS: 2,2,1
2022-07-08 00:13:34 10.164.0.121 [0] LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
2022-07-08 00:13:34 10.164.0.121 [0] TPU_MESH_CONTROLLER_PORT: 8476
2022-07-08 00:13:34 10.164.0.121 [0] XRT_LOCAL_WORKER: c_localservice:0
2022-07-08 00:13:34 10.164.0.121 [0] XRT_SHARD_ORDINAL: 1
2022-07-08 00:13:34 10.164.0.121 [0] TPU_HOST_BOUNDS: 2,2,1
2022-07-08 00:13:34 10.164.0.121 [0] LD_LIBRARY_PATH: :/usr/local/lib
2022-07-08 00:13:34 10.164.0.121 [0] XRT_SHARD_WORLD_SIZE: 32
2022-07-08 00:13:34 10.164.0.121 [0] TPU_MESH_CONTROLLER_ADDRESS: 10.164.0.121:8476
2022-07-08 00:13:34 10.164.0.121 [0] XRT_MESH_SERVICE_ADDRESS: 10.164.0.121:8477
2022-07-08 00:13:34 10.164.0.121 [0] TPUVM_MODE: 1
2022-07-08 00:13:34 10.164.0.121 [0] TF_CPP_MIN_LOG_LEVEL: 1
2022-07-08 00:13:34 10.164.0.121 [0] GRPC_VERBOSITY: ERROR
2022-07-08 00:13:34 10.164.0.121 [0] ALLOW_MULTIPLE_LIBTPU_LOAD: 1
2022-07-08 00:13:34 10.164.0.121 [0] XRT_START_LOCAL_SERVER: 0
2022-07-08 00:13:34 10.164.0.121 [0] TF_GRPC_DEFAULT_OPTIONS: grpc.keepalive_time_ms=60000,grpc.keepalive_timeout_ms=14400000,grpc.http2.max_pings_without_data=0,grpc.http2.min_ping_interval_without_data_ms=300000
2022-07-08 00:13:34 10.164.0.121 [0] XLA_FLAGS:  --xla_cpu_enable_fast_math=false
2022-07-08 00:13:34 10.164.0.121 [0] TPU_LIBRARY_PATH: /usr/local/lib/python3.8/dist-packages/libtpu/libtpu.so
2022-07-08 00:13:34 10.164.0.121 [0] XRT_TORCH_DIST_ROOT: t1v-n-4eb727ae-w-0.europe-west4-a.c.tpu-prod-env-one-vm.internal:55971
2022-07-08 00:13:34 10.164.0.121 [0] XRT_HOST_WORLD_SIZE: 4
2022-07-08 00:13:34 10.164.0.121 [0] XRT_SHARD_LOCAL_ORDINAL: 1
2022-07-08 00:13:34 10.164.0.121 [0] XRT_MULTI_PROCESSING_DEVICE: TPU:1

JackCaoG, Jul 08 '22 00:07

non-master host and master process

2022-07-08 00:22:45 10.164.0.123 [3] TPU_CHIPS_PER_HOST_BOUNDS: 2,2,1
2022-07-08 00:22:45 10.164.0.123 [3] LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
2022-07-08 00:22:45 10.164.0.123 [3] TPU_MESH_CONTROLLER_PORT: 8476
2022-07-08 00:22:45 10.164.0.123 [3] XRT_LOCAL_WORKER: c_localservice:3
2022-07-08 00:22:45 10.164.0.123 [3] XRT_SHARD_ORDINAL: 24
2022-07-08 00:22:45 10.164.0.123 [3] TPU_HOST_BOUNDS: 2,2,1
2022-07-08 00:22:45 10.164.0.123 [3] LD_LIBRARY_PATH: :/usr/local/lib
2022-07-08 00:22:45 10.164.0.123 [3] XRT_SHARD_WORLD_SIZE: 32
2022-07-08 00:22:45 10.164.0.123 [3] TPU_MESH_CONTROLLER_ADDRESS: 10.164.0.121:8476
2022-07-08 00:22:45 10.164.0.123 [3] XRT_MESH_SERVICE_ADDRESS: 10.164.0.121:8477
2022-07-08 00:22:45 10.164.0.123 [3] TPUVM_MODE: 1
2022-07-08 00:22:45 10.164.0.123 [3] TF_CPP_MIN_LOG_LEVEL: 1
2022-07-08 00:22:45 10.164.0.123 [3] GRPC_VERBOSITY: ERROR
2022-07-08 00:22:45 10.164.0.123 [3] ALLOW_MULTIPLE_LIBTPU_LOAD: 1
2022-07-08 00:22:45 10.164.0.123 [3] XRT_START_LOCAL_SERVER: 0
2022-07-08 00:22:45 10.164.0.123 [3] TF_GRPC_DEFAULT_OPTIONS: grpc.keepalive_time_ms=60000,grpc.keepalive_timeout_ms=14400000,grpc.http2.max_pings_without_data=0,grpc.http2.min_ping_interval_without_data_ms=300000
2022-07-08 00:22:45 10.164.0.123 [3] XLA_FLAGS:  --xla_cpu_enable_fast_math=false
2022-07-08 00:22:45 10.164.0.123 [3] TPU_LIBRARY_PATH: /usr/local/lib/python3.8/dist-packages/libtpu/libtpu.so
2022-07-08 00:22:45 10.164.0.123 [3] XRT_TORCH_DIST_ROOT: t1v-n-4eb727ae-w-3.europe-west4-a.c.tpu-prod-env-one-vm.internal:40199
2022-07-08 00:22:45 10.164.0.123 [3] XRT_HOST_WORLD_SIZE: 4
2022-07-08 00:22:45 10.164.0.123 [3] XRT_SHARD_LOCAL_ORDINAL: 0
2022-07-08 00:22:45 10.164.0.123 [3] XRT_MULTI_PROCESSING_DEVICE: TPU:24

non-master host and non-master process

2022-07-08 00:26:50 10.164.0.123 [3] XLA_EMIT_STEPLOG: 1
2022-07-08 00:26:50 10.164.0.123 [3] CLOUD_TPU_TASK_ID: 3
2022-07-08 00:26:50 10.164.0.123 [3] TPU_CHIPS_PER_HOST_BOUNDS: 2,2,1
2022-07-08 00:26:50 10.164.0.123 [3] LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
2022-07-08 00:26:50 10.164.0.123 [3] TPU_MESH_CONTROLLER_PORT: 8476
2022-07-08 00:26:50 10.164.0.123 [3] XRT_LOCAL_WORKER: c_localservice:3
2022-07-08 00:26:50 10.164.0.123 [3] XRT_SHARD_ORDINAL: 25
2022-07-08 00:26:50 10.164.0.123 [3] TPU_HOST_BOUNDS: 2,2,1
2022-07-08 00:26:50 10.164.0.123 [3] LD_LIBRARY_PATH: :/usr/local/lib
2022-07-08 00:26:50 10.164.0.123 [3] XRT_SHARD_WORLD_SIZE: 32
2022-07-08 00:26:50 10.164.0.123 [3] PATH: /usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin
2022-07-08 00:26:50 10.164.0.123 [3] TPU_MESH_CONTROLLER_ADDRESS: 10.164.0.121:8476
2022-07-08 00:26:50 10.164.0.123 [3] XRT_MESH_SERVICE_ADDRESS: 10.164.0.121:8477
2022-07-08 00:26:50 10.164.0.123 [3] TPUVM_MODE: 1
2022-07-08 00:26:50 10.164.0.123 [3] TF_CPP_MIN_LOG_LEVEL: 1
2022-07-08 00:26:50 10.164.0.123 [3] GRPC_VERBOSITY: ERROR
2022-07-08 00:26:50 10.164.0.123 [3] ALLOW_MULTIPLE_LIBTPU_LOAD: 1
2022-07-08 00:26:50 10.164.0.123 [3] XRT_START_LOCAL_SERVER: 0
2022-07-08 00:26:50 10.164.0.123 [3] TF_GRPC_DEFAULT_OPTIONS: grpc.keepalive_time_ms=60000,grpc.keepalive_timeout_ms=14400000,grpc.http2.max_pings_without_data=0,grpc.http2.min_ping_interval_without_data_ms=300000
2022-07-08 00:26:50 10.164.0.123 [3] XLA_FLAGS:  --xla_cpu_enable_fast_math=false
2022-07-08 00:26:50 10.164.0.123 [3] TPU_LIBRARY_PATH: /usr/local/lib/python3.8/dist-packages/libtpu/libtpu.so
2022-07-08 00:26:50 10.164.0.123 [3] XRT_TORCH_DIST_ROOT: t1v-n-4eb727ae-w-3.europe-west4-a.c.tpu-prod-env-one-vm.internal:35145
2022-07-08 00:26:50 10.164.0.123 [3] XRT_HOST_WORLD_SIZE: 4
2022-07-08 00:26:50 10.164.0.123 [3] XRT_SHARD_LOCAL_ORDINAL: 1
2022-07-08 00:26:50 10.164.0.123 [3] XRT_MULTI_PROCESSING_DEVICE: TPU:25
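The four dumps fit a simple layout. XRT_SHARD_ORDINAL = host index (the task number in XRT_LOCAL_WORKER) * 8 + XRT_SHARD_LOCAL_ORDINAL, and XRT_SHARD_WORLD_SIZE = XRT_HOST_WORLD_SIZE * 8 = 32. Here 8 matches TPU_NUM_DEVICES in the first dump and the hardcoded per-host device count mentioned earlier. This is inferred from these four samples only, not from the torch_xla source, and shard_ordinal is a hypothetical helper.

```python
def shard_ordinal(host_index: int, local_ordinal: int, devices_per_host: int = 8) -> int:
    """Global shard ordinal, assuming ordinals are assigned host by host."""
    return host_index * devices_per_host + local_ordinal


if __name__ == "__main__":
    # (task index from XRT_LOCAL_WORKER, XRT_SHARD_LOCAL_ORDINAL) -> XRT_SHARD_ORDINAL,
    # copied from the four dumps above.
    observed = {(0, 0): 0, (0, 1): 1, (3, 0): 24, (3, 1): 25}
    for (host, local), ordinal in observed.items():
        assert shard_ordinal(host, local) == ordinal, (host, local, ordinal)
    # XRT_HOST_WORLD_SIZE (4 hosts) times 8 devices per host.
    print("world size:", 4 * 8)  # matches XRT_SHARD_WORLD_SIZE: 32
```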

JackCaoG, Jul 08 '22 00:07