
[torch-xla v2.8] Error "Check failed: state Expected an array shape." when running test/test_mp_reduce_scatter.py

Open jeffhataws opened this issue 6 months ago • 12 comments

🐛 Bug

With torch-xla v2.8, the Neuron team is getting "Check failed: state Expected an array shape." errors when running many training tests that use reduce-scatter. These errors were not present in v2.7. I have narrowed the regression down to the commit that updated the OpenXLA pin, https://github.com/pytorch/xla/pull/9045, since the torch-xla nightly from 4/30 still works.

I have also narrowed the failure down to the existing test/test_mp_reduce_scatter.py, as shown in the next section.

To Reproduce

On a CPU instance:

python -m venv test_venv_pt2.8
source test_venv_pt2.8/bin/activate
pip3 install -U pip
pip3 install torch torchvision --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install 'torch_xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
cd ~/
git clone https://github.com/pytorch/xla
cd ~/xla/test

Add 'CPU' to the device list in test_mp_reduce_scatter.py:

   if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'NEURON', 'CPU']:

then run

PJRT_DEVICE=CPU python test_mp_reduce_scatter.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR                                                      
F0000 00:00:1749493093.448876   48267 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
*** Check failure stack trace: ***                                                                                                                                           
    @     0x7ea70c886e99  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7ea7017dc8c7  xla::Shape::array_state()
    @     0x7ea701a3b1ce  torch_xla::BuildReduceScatterCoalesced()
    @     0x7ea701e3b264  std::_Function_handler<>::_M_invoke()
    @     0x7ea701ddf751  torch_xla::InferOutputShape()
    @     0x7ea701e3b4a1  std::_Function_handler<>::_M_invoke()
    @     0x7ea701e726df  torch_xla::XlaNode::GetOpShape()
    @     0x7ea701e72fa9  torch_xla::XlaNode::XlaNode()
    @     0x7ea701e3c2ec  torch_xla::ReduceScatterCoalesced::ReduceScatterCoalesced()
    @     0x7ea701ade1b7  torch_xla::MakeNode<>()
    @     0x7ea701ade432  torch_xla::tensor_methods::reduce_scatter_coalesced()
    @     0x7ea70194c02f  torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#55}::operator()()
    @     0x7ea70196ca19  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN() 
    @     0x7ea701948e38  pybind11::cpp_function::dispatcher()
    @     0x61fe15737e12  (unknown)
Traceback (most recent call last):
  File "/home/ubuntu/xla/test/test_mp_reduce_scatter.py", line 180, in <module>
    torch_xla.launch(_mp_fn, args=())
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 266, in launch
    xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator 
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

Expected behavior

No crash. The same test runs cleanly with v2.7:

python -m venv test_venv_pt2.7
source test_venv_pt2.7/bin/activate
pip3 install -U pip
pip install torch torch-xla
cd ~/xla/test
PJRT_DEVICE=CPU python test_mp_reduce_scatter.py

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU, NEURON
  • torch_xla version: 2.8 (TOT)

Additional context

jeffhataws avatar Jun 09 '25 18:06 jeffhataws

Let's enable these tests on CPU if possible. We will also enable this particular one for Neuron.

jeffhataws avatar Jun 09 '25 18:06 jeffhataws

I think that this should be covered as part of https://github.com/pytorch/xla/issues/9315

bhavya01 avatar Jun 10 '25 22:06 bhavya01

@bhavya01 I don't think that #9315 is related as it covers compilation of reduce_scatter. This is not a compilation issue. I believe it's a failure of reduce_scatter_bucketized, which exercises a different code path than reduce_scatter, and which I have not used before. reduce_scatter appears to still work in non-compiled mode.

I agree with @jeffhataws that we should add "CPU" to the device list, since at the moment these tests are not exercised during CI.

As for the issue, the relevant change to XLA seems to be here, but as far as I can tell that just moves the check from shape.h to shape.cc, so I don't see why it would introduce an error. The actual check was added on March 30th, here. I'll ask the XLA team for help.

bfolie avatar Jun 10 '25 23:06 bfolie

Thanks for taking this @bfolie

bhavya01 avatar Jun 11 '25 17:06 bhavya01

One problem is here, where the code checks ShapeHelper::ShapeOfXlaOp(reduce_result).dimensions_size() == 0. Previously, if the result was a tuple of shapes, .dimensions_size() would return 0; now it crashes (XLA PR here). The correct approach is to use ShapeHelper::ShapeOfXlaOp(reduce_result).IsTuple().
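
To make the failure mode concrete, here is a minimal standalone sketch, not the actual torch_xla source, assuming the OpenXLA headers from the updated pin. It builds a tuple shape like the one in the error message and contrasts the old check with the IsTuple() check:

// Standalone sketch (not torch_xla source) of why the old check now aborts.
#include <iostream>

#include "xla/shape.h"
#include "xla/shape_util.h"

int main() {
  // A coalesced reduce-scatter result has a tuple shape like the one in the
  // error message: (f32[32,2,32], f32[32,2,32], ...).
  xla::Shape element = xla::ShapeUtil::MakeShape(xla::F32, {32, 2, 32});
  xla::Shape result = xla::ShapeUtil::MakeTupleShape({element, element});

  // Old check: relied on dimensions_size() returning 0 for tuple shapes.
  // After the pin update, dimensions_size() may only be called on array
  // shapes and CHECK-fails:
  // bool is_tuple = result.dimensions_size() == 0;  // aborts here

  // New check (the approach described above): ask the shape directly.
  bool is_tuple = result.IsTuple();
  std::cout << "IsTuple: " << std::boolalpha << is_tuple << std::endl;
  return 0;
}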

Making that change, as in #9347, causes test_mp_reduce_scatter.py to pass on CPU, but it still fails on TPU (the operative difference is likely not CPU vs. TPU but one process vs. multiple processes). It's the same error message, but there's no useful information in the C++ stack trace (below). The error occurs inside this line, client_->CompileAndLoad(instance.computation, compile_options);, but that can't be the underlying cause. My hunch is that something in the ReduceScatter lowering code is calling an invalid shape op (treating a list of shapes as if it were a single shape), but nothing has jumped out at me yet, and that wouldn't explain why it passes on CPU but not on TPU.

*** Check failure stack trace: ***
    @     0x7f509eeefc64  (unknown)
    @     0x7f509eeefc18  (unknown)
    @     0x7f509df9ff45  (unknown)
    @     0x7f50973ee1ee  (unknown)
    @     0x7f509732345c  (unknown)
    @     0x7f509731fdfb  (unknown)
    @     0x7f5096a31658  (unknown)
    @     0x7f5096a3bf38  (unknown)
    @     0x7f5096a40485  (unknown)
    @     0x7f509be7a264  (unknown)
    @     0x7f509be79acf  (unknown)
    @     0x7f509648c746  (unknown)
    @     0x7f5099087413  (unknown)
    @     0x7f509ead0085  (unknown)
    @     0x7f509ead53b6  (unknown)
    @     0x7f509eade272  (unknown)
    @     0x7f509ed4b543  (unknown)
    @     0x7f57d873dea7  start_thread
https://symbolize.stripped_domain/r/?trace=7f509eeefc64,7f509eeefc17,7f509df9ff44,7f50973ee1ed,7f509732345b,7f509731fdfa,7f5096a31657,7f5096a3bf37,7f5096a40484,7f509be7a263,7f509be79ace,7f509648c745,7f5099087412,7f509ead0084,7f509ead53b5,7f509eade271,7f509ed4b542,7f57d873dea6&map= 
https://symbolize.stripped_domain/r/?trace=7f57d8790d51,7f57d8790dcf,7f509eeefcc8,7f509eeefc17,7f509df9ff44,7f50973ee1ed,7f509732345b,7f509731fdfa,7f5096a31657,7f5096a3bf37,7f5096a40484,7f509be7a263,7f509be79ace,7f509648c745,7f5099087412,7f509ead0084,7f509ead53b5,7f509eade271,7f509ed4b542,7f57d873dea6&map= 
*** SIGABRT received by PID 3088866 (TID 3091245) on cpu 223 from PID 3088866; ***
E0612 06:15:07.981904 3091245 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0612 06:15:07.981912 3091245 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0612 06:15:07.981920 3091245 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0612 06:15:07.981934 3091245 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0612 06:15:07.981939 3091245 coredump_hook.cc:457] RAW: Dumping core locally.
F0612 06:15:07.785581 3091210 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:43.987402 3091210 process_state.cc:808] RAW: Raising signal 6 with default behavior
https://symbolize.stripped_domain/r/?trace=7f0e30d4e7b2,7f0e30d9adcf&map= 
E0612 06:15:45.128389 3089552 process_state.cc:1175] RAW: Signal 15 raised at PC: 0x7f0e30d4e7b2 while already in FailureSignalHandler!
E0612 06:15:45.128415 3089552 process_state.cc:1179] RAW: tid: 3089552 raised new signal (old_tid: 3091098)
https://symbolize.stripped_domain/r/?trace=7f57d87447b2,7f57d8790dcf&map= 
E0612 06:15:45.132369 3089209 process_state.cc:1175] RAW: Signal 15 raised at PC: 0x7f57d87447b2 while already in FailureSignalHandler!
E0612 06:15:45.132396 3089209 process_state.cc:1179] RAW: tid: 3089209 raised new signal (old_tid: 3091245)
https://symbolize.stripped_domain/r/?trace=7f3716efe13c,7f3e4e7d2dcf,7f3716e76e44,7f3716f00833,7f370e174e59,7f370e174971,7f3716d4b542,7f3e4e77fea6&map= 
E0612 06:15:45.132588 3091562 process_state.cc:1175] RAW: Signal 15 raised at PC: 0x7f3716efe13c while already in FailureSignalHandler!
E0612 06:15:45.132614 3091562 process_state.cc:1179] RAW: tid: 3091562 raised new signal (old_tid: 3091664)
F0612 06:15:07.804218 3091098 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:45.207711 3091098 process_state.cc:808] RAW: Raising signal 6 with default behavior
F0612 06:15:07.830550 3091245 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:45.211959 3091245 process_state.cc:808] RAW: Raising signal 6 with default behavior
F0612 06:15:07.833153 3091664 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:45.228210 3091664 process_state.cc:808] RAW: Raising signal 6 with default behavior
Traceback (most recent call last):
  File "/workspaces/torch/pytorch/xla/test/test_mp_reduce_scatter.py", line 182, in <module>
    torch_xla.launch(_mp_fn, args=())
  File "/workspaces/torch/pytorch/xla/torch_xla/torch_xla.py", line 266, in launch
    xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
  File "/workspaces/torch/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/workspaces/torch/pytorch/xla/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/workspaces/torch/pytorch/xla/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/workspaces/torch/pytorch/xla/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

bfolie avatar Jun 12 '25 06:06 bfolie

Thanks @bfolie. Do you know if this line is also affected?

jeffhataws avatar Jun 12 '25 20:06 jeffhataws

I don't think so. .tuple_shapes() will work as long as all_gather_result has a tuple shape, and I think in that context it does. But I haven't tested it.
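
For reference, a minimal sketch of the defensive pattern, not the actual torch_xla code and not the change in any particular PR; the helper NumResults and its argument are hypothetical. tuple_shapes() now also CHECK-fails on array shapes, so any call site that can receive either a single array shape or a tuple of shapes needs an explicit IsTuple() branch:

#include <cstddef>

#include "xla/shape.h"

// Hypothetical helper, for illustration only: 'result_shape' stands in for
// something like ShapeHelper::ShapeOfXlaOp(all_gather_result).
size_t NumResults(const xla::Shape& result_shape) {
  if (result_shape.IsTuple()) {
    // Coalesced case: one sub-shape per input tensor.
    return result_shape.tuple_shapes().size();
  }
  // Single-tensor case: an array shape such as bf16[4008,4096].
  return 1;
}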

bfolie avatar Jun 12 '25 20:06 bfolie

Thanks @bfolie for the fix! I have confirmed that it works for Neuron.

jeffhataws avatar Jun 17 '25 04:06 jeffhataws

For some reason multi-node all-gather is now crashing. Let me debug and isolate a testcase. The crash trace is below in case you know where to look:

F0616 05:45:23.533522    7492 shape.cc:184] Check failed: state Expected a tuple shape. Got bf16[4008,4096]
This is a programmer error. Please read the Shape object's tuple properties (e.g. tuple_shapes) only when it's a tuple shape.
*** Check failure stack trace: ***
    @     0x7ff941f4c039  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7ff936e4fae2  xla::Shape::tuple_state()
    @     0x7ff941cd6bc9  xla::Shape::tuple_shapes()
    @     0x7ff9370baf9b  torch_xla::BuildAllGatherCoalesced()
    @     0x7ff937410966  std::_Function_handler<>::_M_invoke()
    @     0x7ff93745fdf1  torch_xla::InferOutputShape()
    @     0x7ff937410b73  std::_Function_handler<>::_M_invoke()
    @     0x7ff9374f2d7f  torch_xla::XlaNode::GetOpShape()
    @     0x7ff9374f3649  torch_xla::XlaNode::XlaNode()
    @     0x7ff937411a55  torch_xla::AllGatherCoalesced::AllGatherCoalesced()
    @     0x7ff93715d7fe  torch_xla::MakeNode<>()
    @     0x7ff93715dacf  torch_xla::tensor_methods::all_gather_coalesced()
    @     0x7ff936fc36c5  torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#45}::operator()()
    @     0x7ff936fee0bb  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7ff936fbf618  pybind11::cpp_function::dispatcher()
    @     0x564d9122de12  (unknown)

jeffhataws avatar Jun 17 '25 18:06 jeffhataws

@bfolie https://github.com/pytorch/xla/pull/9403 is the fix for the all-gather issue above. I will try to narrow it down to a smaller unit test.

jeffhataws avatar Jun 24 '25 17:06 jeffhataws

Narrowed it down to a single-node test, but it still has NeuronX Distributed and real-dataset dependencies. Will narrow it down some more.

jeffhataws avatar Jun 27 '25 16:06 jeffhataws

Seems to happen with TP + ZeRO1.

jeffhataws avatar Jun 27 '25 16:06 jeffhataws