[torch-xla v2.8] Error "Check failed: state Expected an array shape." when running test/test_mp_reduce_scatter.py
🐛 Bug
With torch-xla v2.8, the Neuron team is getting "Check failed: state Expected an array shape." errors when running many training tests that use reduce-scatter. These errors were not present in v2.7. I have narrowed the regression down to the commit that updated the OpenXLA pin (https://github.com/pytorch/xla/pull/9045), since the torch-xla nightly from 4/30 still works.
I have also narrowed the failure down to the existing test/test_mp_reduce_scatter.py, as shown in the next section.
To Reproduce
On a CPU instance:
python -m venv test_venv_pt2.8
source test_venv_pt2.8/bin/activate
pip3 install -U pip
pip3 install torch torchvision --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install 'torch_xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
cd ~/
git clone https://github.com/pytorch/xla
cd ~/xla/test
Add 'CPU' to the device list in test_mp_reduce_scatter.py:
if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'NEURON', 'CPU']:
Then run:
PJRT_DEVICE=CPU python test_mp_reduce_scatter.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1749493093.448876 48267 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
*** Check failure stack trace: ***
@ 0x7ea70c886e99 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7ea7017dc8c7 xla::Shape::array_state()
@ 0x7ea701a3b1ce torch_xla::BuildReduceScatterCoalesced()
@ 0x7ea701e3b264 std::_Function_handler<>::_M_invoke()
@ 0x7ea701ddf751 torch_xla::InferOutputShape()
@ 0x7ea701e3b4a1 std::_Function_handler<>::_M_invoke()
@ 0x7ea701e726df torch_xla::XlaNode::GetOpShape()
@ 0x7ea701e72fa9 torch_xla::XlaNode::XlaNode()
@ 0x7ea701e3c2ec torch_xla::ReduceScatterCoalesced::ReduceScatterCoalesced()
@ 0x7ea701ade1b7 torch_xla::MakeNode<>()
@ 0x7ea701ade432 torch_xla::tensor_methods::reduce_scatter_coalesced()
@ 0x7ea70194c02f torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#55}::operator()()
@ 0x7ea70196ca19 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7ea701948e38 pybind11::cpp_function::dispatcher()
@ 0x61fe15737e12 (unknown)
Traceback (most recent call last):
File "/home/ubuntu/xla/test/test_mp_reduce_scatter.py", line 180, in <module>
torch_xla.launch(_mp_fn, args=())
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 266, in launch
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
replica_results = list(
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
Expected behavior
No crash, as is the case when running the same test with v2.7:
python -m venv test_venv_pt2.7
source test_venv_pt2.7/bin/activate
pip3 install -U pip
pip3 install torch torch_xla
cd ~/xla/test
PJRT_DEVICE=CPU python test_mp_reduce_scatter.py
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU, NEURON
- torch_xla version: 2.8 (TOT)
Additional context
Let's enable these tests on CPU if possible. We will also enable this particular one for Neuron.
I think that this should be covered as part of https://github.com/pytorch/xla/issues/9315
@bhavya01 I don't think that #9315 is related, as it covers compilation of reduce_scatter and this is not a compilation issue. I believe it's a failure of reduce_scatter_bucketized, which exercises a different code path than reduce_scatter and which I have not used before. reduce_scatter itself still appears to work in non-compiled mode.
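For reference, the two call forms in test_mp_reduce_scatter.py look roughly like the sketch below (simplified from the test; the exact argument order and keyword names are my reading of the xm API and should be treated as assumptions, not documentation). The list-input bucketized form is what goes through tensor_methods::reduce_scatter_coalesced / BuildReduceScatterCoalesced in the stack trace above.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _sketch_fn(index):
  device = xm.xla_device()
  world_size = xr.world_size()
  scale = 1.0 / world_size
  scatter_dim = 0

  # Plain reduce-scatter: one tensor in, one shard out. This path still works
  # in non-compiled mode.
  x = torch.rand(world_size * 8, 16, device=device)
  y = xm.reduce_scatter(xm.REDUCE_SUM, x, scale, scatter_dim, world_size)

  # Bucketized reduce-scatter: a *list* of tensors in, which is lowered through
  # the coalesced path (reduce_scatter_coalesced / BuildReduceScatterCoalesced,
  # i.e. the frames in the trace above) and is where the shape CHECK fires.
  xs = [torch.rand(world_size * 8, 16, device=device) for _ in range(3)]
  ys = xm.reduce_scatter_bucketized(
      xm.REDUCE_SUM, xs, scale, scatter_dim, world_size, pin_layout=False)

  torch_xla.sync()
```

In the real test this body runs under torch_xla.launch, the same way as in the repro above.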
I agree with @jeffhataws that we should add "CPU" to the device list, since at the moment these tests are not exercised during CI.
As for the issue, the relevant change to XLA seems to be here, but as far as I can tell that just moves the check from shape.h to shape.cc, so I don't know why it would introduce an error. The actual check was added on March 30th, here. I'll ask the XLA team for help.
Thanks for taking this @bfolie
One problem is here, where it calls ShapeHelper::ShapeOfXlaOp(reduce_result).dimensions_size() == 0. Previously, if the result was a tuple of shapes, .dimensions_size() would return 0; now it crashes (XLA PR here). The correct approach is to use ShapeHelper::ShapeOfXlaOp(reduce_result).IsTuple().
Making that change, as in #9347, causes test_mp_reduce_scatter.py to pass on CPU, but it still fails on TPU (the operative difference is likely not CPU vs. TPU but one process vs. multiple processes). It's the same error message, but there's no useful information in the C++ stack trace (below). The error occurs inside this line, client_->CompileAndLoad(instance.computation, compile_options);, but that can't be the underlying cause. My hunch is that something in the ReduceScatter lowering code is calling an invalid shape op (treating a list of shapes as if it were a single shape), but nothing has jumped out at me yet. And that doesn't explain why it would pass on CPU but not TPU.
*** Check failure stack trace: ***
@ 0x7f509eeefc64 (unknown)
@ 0x7f509eeefc18 (unknown)
@ 0x7f509df9ff45 (unknown)
@ 0x7f50973ee1ee (unknown)
@ 0x7f509732345c (unknown)
@ 0x7f509731fdfb (unknown)
@ 0x7f5096a31658 (unknown)
@ 0x7f5096a3bf38 (unknown)
@ 0x7f5096a40485 (unknown)
@ 0x7f509be7a264 (unknown)
@ 0x7f509be79acf (unknown)
@ 0x7f509648c746 (unknown)
@ 0x7f5099087413 (unknown)
@ 0x7f509ead0085 (unknown)
@ 0x7f509ead53b6 (unknown)
@ 0x7f509eade272 (unknown)
@ 0x7f509ed4b543 (unknown)
@ 0x7f57d873dea7 start_thread
https://symbolize.stripped_domain/r/?trace=7f509eeefc64,7f509eeefc17,7f509df9ff44,7f50973ee1ed,7f509732345b,7f509731fdfa,7f5096a31657,7f5096a3bf37,7f5096a40484,7f509be7a263,7f509be79ace,7f509648c745,7f5099087412,7f509ead0084,7f509ead53b5,7f509eade271,7f509ed4b542,7f57d873dea6&map=
https://symbolize.stripped_domain/r/?trace=7f57d8790d51,7f57d8790dcf,7f509eeefcc8,7f509eeefc17,7f509df9ff44,7f50973ee1ed,7f509732345b,7f509731fdfa,7f5096a31657,7f5096a3bf37,7f5096a40484,7f509be7a263,7f509be79ace,7f509648c745,7f5099087412,7f509ead0084,7f509ead53b5,7f509eade271,7f509ed4b542,7f57d873dea6&map=
*** SIGABRT received by PID 3088866 (TID 3091245) on cpu 223 from PID 3088866; ***
E0612 06:15:07.981904 3091245 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0612 06:15:07.981912 3091245 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0612 06:15:07.981920 3091245 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0612 06:15:07.981934 3091245 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0612 06:15:07.981939 3091245 coredump_hook.cc:457] RAW: Dumping core locally.
F0612 06:15:07.785581 3091210 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:43.987402 3091210 process_state.cc:808] RAW: Raising signal 6 with default behavior
https://symbolize.stripped_domain/r/?trace=7f0e30d4e7b2,7f0e30d9adcf&map=
E0612 06:15:45.128389 3089552 process_state.cc:1175] RAW: Signal 15 raised at PC: 0x7f0e30d4e7b2 while already in FailureSignalHandler!
E0612 06:15:45.128415 3089552 process_state.cc:1179] RAW: tid: 3089552 raised new signal (old_tid: 3091098)
https://symbolize.stripped_domain/r/?trace=7f57d87447b2,7f57d8790dcf&map=
E0612 06:15:45.132369 3089209 process_state.cc:1175] RAW: Signal 15 raised at PC: 0x7f57d87447b2 while already in FailureSignalHandler!
E0612 06:15:45.132396 3089209 process_state.cc:1179] RAW: tid: 3089209 raised new signal (old_tid: 3091245)
https://symbolize.stripped_domain/r/?trace=7f3716efe13c,7f3e4e7d2dcf,7f3716e76e44,7f3716f00833,7f370e174e59,7f370e174971,7f3716d4b542,7f3e4e77fea6&map=
E0612 06:15:45.132588 3091562 process_state.cc:1175] RAW: Signal 15 raised at PC: 0x7f3716efe13c while already in FailureSignalHandler!
E0612 06:15:45.132614 3091562 process_state.cc:1179] RAW: tid: 3091562 raised new signal (old_tid: 3091664)
F0612 06:15:07.804218 3091098 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:45.207711 3091098 process_state.cc:808] RAW: Raising signal 6 with default behavior
F0612 06:15:07.830550 3091245 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:45.211959 3091245 process_state.cc:808] RAW: Raising signal 6 with default behavior
F0612 06:15:07.833153 3091664 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
E0612 06:15:45.228210 3091664 process_state.cc:808] RAW: Raising signal 6 with default behavior
Traceback (most recent call last):
File "/workspaces/torch/pytorch/xla/test/test_mp_reduce_scatter.py", line 182, in <module>
torch_xla.launch(_mp_fn, args=())
File "/workspaces/torch/pytorch/xla/torch_xla/torch_xla.py", line 266, in launch
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
File "/workspaces/torch/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/workspaces/torch/pytorch/xla/torch_xla/_internal/pjrt.py", line 213, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/workspaces/torch/pytorch/xla/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
replica_results = list(
File "/workspaces/torch/pytorch/xla/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
itertools.chain.from_iterable(
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
Thanks @bfolie. Do you know if this line is also affected?
I don't think so. .tuple_shapes() will work as long as all_gather_result has a tuple shape, and I think in that context it does. But I haven't tested it.
Thanks @bfolie for the fix! I have confirmed that it works for Neuron.
For some reason, multi-node all-gather is now crashing. Let me debug and isolate a test case. The crash trace is below in case you know where to look:
F0616 05:45:23.533522 7492 shape.cc:184] Check failed: state Expected a tuple shape. Got bf16[4008,4096]
This is a programmer error. Please read the Shape object's tuple properties (e.g. tuple_shapes) only when it's a tuple shape.
*** Check failure stack trace: ***
@ 0x7ff941f4c039 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7ff936e4fae2 xla::Shape::tuple_state()
@ 0x7ff941cd6bc9 xla::Shape::tuple_shapes()
@ 0x7ff9370baf9b torch_xla::BuildAllGatherCoalesced()
@ 0x7ff937410966 std::_Function_handler<>::_M_invoke()
@ 0x7ff93745fdf1 torch_xla::InferOutputShape()
@ 0x7ff937410b73 std::_Function_handler<>::_M_invoke()
@ 0x7ff9374f2d7f torch_xla::XlaNode::GetOpShape()
@ 0x7ff9374f3649 torch_xla::XlaNode::XlaNode()
@ 0x7ff937411a55 torch_xla::AllGatherCoalesced::AllGatherCoalesced()
@ 0x7ff93715d7fe torch_xla::MakeNode<>()
@ 0x7ff93715dacf torch_xla::tensor_methods::all_gather_coalesced()
@ 0x7ff936fc36c5 torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#45}::operator()()
@ 0x7ff936fee0bb pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7ff936fbf618 pybind11::cpp_function::dispatcher()
@ 0x564d9122de12 (unknown)
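At the Python level, the frames above correspond to the coalesced all-gather path; roughly the following kinds of calls reach it (a simplified sketch, with the exact signatures treated as my assumption rather than documentation):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def _all_gather_sketch():
  device = xm.xla_device()

  # Single-tensor all-gather (non-coalesced path).
  x = torch.rand(4008, 4096, dtype=torch.bfloat16, device=device)
  y = xm.all_gather(x, dim=0)

  # List-input / bucketized all-gather, which lowers through
  # tensor_methods::all_gather_coalesced / BuildAllGatherCoalesced (the frames
  # in the trace above). The "Expected a tuple shape. Got bf16[4008,4096]"
  # message suggests the coalesced builder received a single array shape where
  # it expected a tuple of shapes.
  xs = [
      torch.rand(4008, 4096, dtype=torch.bfloat16, device=device)
      for _ in range(2)
  ]
  ys = xm.all_gather_bucketized(xs, dim=0, pin_layout=False)

  torch_xla.sync()
```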
@bfolie https://github.com/pytorch/xla/pull/9403 is the fix for the all-gather issue above. I will try to narrow it down to a smaller unit test.
I've narrowed it down to a single-node test, but it still has NeuronX Distributed and real-dataset dependencies. Will narrow it down some more.
It seems to happen with TP + ZeRO-1.
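For context, the ZeRO-1 part of that setup is roughly the sketch below (heavily simplified: the real run also uses NeuronX Distributed tensor parallelism and a real dataset, which are omitted here, and the ZeroRedundancyOptimizer arguments are my best guess rather than a confirmed repro). ZeRO-1 is relevant because its step() reduce-scatters gradients and all-gathers the updated parameter shards, which is how a training loop ends up on the all-gather path in the trace above.

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer


def _zero1_sketch():
  device = xm.xla_device()
  # Layer sized so its weight matches the bf16[4008,4096] shape in the trace
  # (nn.Linear stores weight as [out_features, in_features]); purely illustrative.
  model = nn.Linear(4096, 4008, dtype=torch.bfloat16).to(device)

  # ZeRO-1: optimizer state is sharded across ranks; gradients are
  # reduce-scattered and updated parameter shards are all-gathered in step().
  optimizer = ZeroRedundancyOptimizer(
      model.parameters(), optimizer_class=torch.optim.SGD, lr=1e-3)

  x = torch.rand(8, 4096, dtype=torch.bfloat16, device=device)
  loss = model(x).sum()
  loss.backward()
  optimizer.step()
  torch_xla.sync()
```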