host_callback_test.py is failing on multi-GPU platforms
~/jax$ python3 tests/host_callback_test.py
fails with:
[ RUN ] CallJaxTest.test_jax_grad
2021-04-15 00:16:53.333864: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:56] Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape u32[2]{0} and outfeed source buffer shape f32[]
Mismatch between outfeed output buffer shape u32[2]{0} and outfeed source buffer shape f32[]
All test cases seem to fail this way, though it's a little hard to be sure: after a few test cases run and fail with this error, the test just hangs until you kill it (probably because computations are blocked waiting for infeed).
I haven't dug into this at all, but for now I'm planning to disable all of host_callback_test on multi-GPU platforms in order to make our test suite runnable on multi-GPU (since right now it hangs).
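For illustration, a guard along these lines could implement the skip. This is only a sketch, not the actual change made to the test suite: `should_skip_host_callback_tests` and `HostCallbackGuardedTest` are hypothetical names, and the `platform` and `device_count` class attributes are stand-ins so the snippet runs without JAX installed.

```python
import unittest


def should_skip_host_callback_tests(platform: str, device_count: int) -> bool:
    """Hypothetical helper: True when running on a GPU backend with >1 device."""
    return platform == "gpu" and device_count > 1


class HostCallbackGuardedTest(unittest.TestCase):
    # Stand-in values (assumption). In a real JAX test these would come from
    # jax.default_backend() and jax.device_count().
    platform = "cpu"
    device_count = 1

    def setUp(self):
        super().setUp()
        # Skip before any computation starts, so nothing blocks on infeed.
        if should_skip_host_callback_tests(self.platform, self.device_count):
            raise unittest.SkipTest(
                "host_callback tests are disabled on multi-GPU platforms")

    def test_placeholder(self):
        # Placeholder body; the real test cases would go here.
        self.assertTrue(True)
```

Doing the check in `setUp` means every test case in the class is skipped up front, which should avoid the hang described above rather than only hiding the failures.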
I think infeed and outfeed aren't working on multi-GPU at the moment. I believe we end up with a single shared infeed queue rather than one per GPU.
Duplicate of #5577
Closing since duplicate