
`host_callback.call` fails on multi-gpu machine

Open C-J-Cundy opened this issue 3 years ago • 4 comments

If I run the following code:

from jax.experimental import host_callback
import numpy as np
from jax import pmap, jit, partial, ShapeDtypeStruct


def host_fn(x):
    return x


x = np.ones(4, dtype=np.float32)
host_callback.call(host_fn, x, result_shape=x)

on a 2-GPU machine, it crashes with the following error message:

2021-01-31 17:53:40.778121: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:56] Check failed: ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()) XLA program outfeed request of shape (f32[4]) did not match the runtime's outfeed buffer of shape u32[2]

If I run with one GPU (by setting CUDA_VISIBLE_DEVICES=0) it finishes with no errors. Is there something I've missed in the documentation for host_callback about how it should be used on multi-device setups?

In case it's helpful, I ran both configurations with full debug logging. For the one-GPU run the command was:

CUDA_VISIBLE_DEVICES=0 TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,xfeed_manager=3 python test_2.py --verbosity=2 2> test_output_one_gpu.txt

Attached logs: test_output_one_gpu.txt, test_output_two_gpu.txt

C-J-Cundy avatar Feb 01 '21 02:02 C-J-Cundy

I think that this is not specific to multi-GPU, but can happen even with one GPU (randomly). I think it is related to #4374.

There are two fixes possible: fix the implementation of outfeed for XLA:GPU, or replace the implementation mechanism for GPUs to use CustomCall (this is in progress).

gnecula avatar Feb 02 '21 11:02 gnecula

Is this any closer to being fixed, or are there any ideas for a workaround? host_callback is a really great addition to jax. It's a bit frustrating that I currently can't use it with multiple GPUs.

C-J-Cundy avatar Mar 17 '21 23:03 C-J-Cundy

There are two updates. It turns out that the infeed/outfeed in XLA:GPU is not so easy to fix for multi-GPU. So that hope has gotten dimmer.

The second update is more positive: we have a new implementation in the works for GPU, using XLA CustomCall. This means that host_callback will be synchronous. This implementation was blocked on GPU by another XLA bug, which has since been fixed. So the plan is to enable this second implementation mechanism, selectable with an environment variable and a command-line flag. This change involves both Python and C++ and will take at least a couple of weeks to land. Sorry for the delay!

gnecula avatar Mar 18 '21 08:03 gnecula

Sorry for the bump, but what's the current status of the second update?

AllanChain avatar Feb 07 '22 11:02 AllanChain

> Sorry for the bump, but what's the current status of the second update?

The custom call on GPU has now landed but is not used in host_callback quite yet. You can try out the new callback mechanism on GPU with jax.debug.print, and we should be porting host_callback (HCB) to use the new custom call very soon.
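A minimal sketch of trying jax.debug.print inside a jitted function, assuming a JAX version that ships jax.debug.print and jax.effects_barrier. The function name `scaled` is hypothetical and only for illustration:

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled(x):
    # Prints the runtime value of x from inside the compiled function.
    jax.debug.print("x = {x}", x=x)
    return 2.0 * x


y = scaled(jnp.ones(4, dtype=jnp.float32))
# Wait for pending side effects (the print) before moving on.
jax.effects_barrier()
```

Whether this avoids the outfeed crash on a particular multi-GPU machine depends on the installed JAX and jaxlib versions.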

sharadmv avatar Aug 13 '22 17:08 sharadmv

@C-J-Cundy @AllanChain can you say more about your intended use case? For example, is it to have a callback for a debugging side-effect (like printing), or to perform some functionally pure numerical computation (on the host?), or something else? I ask because if it's one of those two applications we can recommend a replacement API (without having to wait for porting the HCB API to use a new implementation).

mattjj avatar Aug 24 '22 19:08 mattjj

> functionally pure numerical computation (on the host?),

What would be the replacement API in that case?

PhilipVinc avatar Aug 24 '22 21:08 PhilipVinc

It is jax.pure_callback.
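As a rough sketch, the identity callback from the original report might be written with jax.pure_callback as follows. This assumes a JAX version that provides jax.pure_callback. The wrapper name `call_host` is hypothetical, and the host function must be free of side effects:

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fn(x):
    # Runs on the host; must be pure and return the declared shape/dtype.
    return np.asarray(x)


@jax.jit
def call_host(x):
    # Describe the callback's result so it can be traced under jit.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, result_shape, x)


x = jnp.ones(4, dtype=jnp.float32)
y = call_host(x)
```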

sharadmv avatar Aug 25 '22 03:08 sharadmv

> is it to have a callback for a debugging side-effect (like printing), or to perform some functionally pure numerical computation (on the host?)

It was the former. But I have figured out my problem and am no longer waiting for this.

AllanChain avatar Aug 25 '22 06:08 AllanChain

For reference, the "callback for a debugging side-effect (like printing)" is jax.debug.callback
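A small sketch of a side-effecting debug callback, assuming a JAX version with jax.debug.callback and jax.effects_barrier. The names `seen`, `record_on_host`, and `doubled` are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # host-side record of values observed by the callback


def record_on_host(x):
    # Side effect on the host; the return value is ignored.
    seen.append(np.array(x))


@jax.jit
def doubled(x):
    jax.debug.callback(record_on_host, x)
    return x * 2


y = doubled(jnp.ones(4, dtype=jnp.float32))
# Make sure the callback has run before inspecting `seen`.
jax.effects_barrier()
```

Unlike jax.pure_callback, the callback here returns nothing to the computation, so it suits logging or inspection rather than host-side numerics.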

sharadmv avatar Aug 25 '22 22:08 sharadmv