host_callback_test.py is failing on multi-GPU platforms

Open skye opened this issue 3 years ago • 2 comments

~/jax$ python3 tests/host_callback_test.py fails with:

[ RUN      ] CallJaxTest.test_jax_grad
2021-04-15 00:16:53.333864: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:56] Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape u32[2]{0} and outfeed source buffer shape f32[]

Mismatch between outfeed output buffer shape u32[2]{0} and outfeed source buffer shape f32[]

All test cases seem to fail this way. It's a little hard to tell because after a few test cases run and fail with this, the test will just hang until you kill it (probably because computations are blocked waiting for infeed).

I haven't dug into this at all, but for now I'm planning to disable all of host_callback_test on multi-GPU platforms in order to make our test suite runnable on multi-GPU (since right now it hangs).
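For illustration, a minimal sketch of what such a skip could look like. This is not the change that was actually made to host_callback_test.py. The helper `should_skip_host_callback` and the class `HostCallbackSkipExample` are hypothetical names, and the set of platform strings treated as "GPU" is an assumption. Only `jax.default_backend()` and `jax.device_count()` are real JAX calls.

```python
import unittest


def should_skip_host_callback(platform: str, device_count: int) -> bool:
    """Hypothetical helper: True when running on more than one GPU.

    The platform names checked here are an assumption, chosen to cover
    the strings a GPU backend might report.
    """
    return platform in ("gpu", "cuda", "rocm") and device_count > 1


class HostCallbackSkipExample(unittest.TestCase):
    """Hypothetical test class showing where the skip would go."""

    def setUp(self):
        # Imported lazily so the helper above can be used without JAX installed.
        import jax

        if should_skip_host_callback(jax.default_backend(), jax.device_count()):
            self.skipTest("host_callback tests fail and then hang on multi-GPU")

    def test_placeholder(self):
        # Stand-in for a real host_callback test case.
        self.assertTrue(True)
```

Putting the check in `setUp` skips every test case in the class, which matters here because the failure mode is a hang rather than a clean error.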

skye, Apr 15 '21 00:04

I think infeed and outfeed aren't working on multi-GPU at the moment. I believe we end up with a single shared infeed queue rather than one per GPU.
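A toy model of why a single shared queue could produce a shape mismatch like the one in the log. This is plain Python, not XLA's implementation, and it only illustrates the hypothesis in the comment above: `read_shared`, `read_per_device`, and the shape strings are all made up for the example.

```python
from collections import deque


def read_shared(writes, reads):
    """All devices push to and pop from one queue (the suspected behaviour)."""
    queue = deque(shape for _, shape in writes)
    mismatches = []
    for device, expected in reads:
        got = queue.popleft()  # head of the queue, whoever produced it
        if got != expected:
            mismatches.append((device, expected, got))
    return mismatches


def read_per_device(writes, reads):
    """Each device has its own queue (the intended behaviour)."""
    queues = {}
    for device, shape in writes:
        queues.setdefault(device, deque()).append(shape)
    mismatches = []
    for device, expected in reads:
        got = queues[device].popleft()  # only this device's own data
        if got != expected:
            mismatches.append((device, expected, got))
    return mismatches


# Device 0 produces a u32[2] value, device 1 produces an f32[] value.
writes = [(0, "u32[2]"), (1, "f32[]")]
# The readers happen to run in the opposite order.
reads = [(1, "f32[]"), (0, "u32[2]")]

print(read_shared(writes, reads))      # both reads see the other device's shape
print(read_per_device(writes, reads))  # no mismatches
```

With one queue, the result depends on the order in which devices read, so a reader can receive a buffer of the wrong shape. With one queue per device, ordering across devices does not matter.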

hawkinsp, Apr 15 '21 02:04

Duplicate of #5577

gnecula, Apr 15 '21 08:04

Closing since duplicate

sudhakarsingh27, Aug 24 '22 19:08