Using a device function inside the host function of host_callback fails confusingly
Using a JAX device function in the host function of `host_callback` causes hangs, long compiles, or OOMs (the exact failure mode varies with the backend). This not being supported is probably working as intended (WAI), but we could provide a better error message to avoid user confusion. Minimal reproducer below:
```python
import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def host_fun(m: np.ndarray) -> np.ndarray:
    return np.sum(m)    # works
    # return jnp.sum(m) # causes errors

def device_fun(m):
    return hcb.call(host_fun, m,
                    result_shape=jax.ShapeDtypeStruct(m.shape, m.dtype))

jax.jit(device_fun)(0)
```
The success of the reentrant call depends on the device configuration:
- If there is only a single CPU device available, the call causes a deadlock.
- If two or more CPU devices are explicitly made visible, one can create a working configuration by JITing `device_fun` onto one device and forcing `host_fun` onto a different device via the `device` kwarg of `jit()` (see the sketch below).
- Similarly, a working configuration can be accomplished by spreading `host_fun` and `device_fun` across a CPU and a GPU device.

It is less clear what happens under transformations like `pmap`.
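A minimal sketch of the two-CPU-device workaround, assuming JAX was started with `XLA_FLAGS=--xla_force_host_platform_device_count=2` so that two CPU devices are visible (the names below are mine, not from the thread):

```python
import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

cpu0, cpu1 = jax.devices("cpu")[:2]

def host_fun(m):
    # Re-enter JAX from the host, but pin the computation to a *different*
    # device than the one that is blocked waiting on the callback.
    return np.asarray(jax.jit(jnp.sum, device=cpu1)(m))

def device_fun(m):
    return hcb.call(host_fun, m,
                    result_shape=jax.ShapeDtypeStruct((), m.dtype))

print(jax.jit(device_fun, device=cpu0)(jnp.arange(4)))
```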
Allowing re-entrant calls is a larger project, but I will look into providing a better message.
The rule is that when a callback executes on a device it blocks the device until the callback finishes. The callback cannot launch other computations on the same device.
Related to this potential reentrant call, but looking at it from a more general perspective.

One of the potential use cases for `jax.experimental.host_callback.call` is to get around the fixed-shape requirements of `jax.jit` for some specific intermediate computations, since anything that happens in the host callback can take any arbitrary shape (see the sketch after this list). This can be very valuable for prototyping and research code. Of course, using `jax.experimental.host_callback.call` for this purpose means:
- Data to and from `jax.experimental.host_callback.call` needs to be copied from the device to the host and vice versa, which adds overhead.
- There cannot be any device parallelization in the callback.
It would be fantastic if the inputs and outputs to the `host_fun` remained DeviceArrays instead of NumPy arrays. I understand this may cause complications for XLA optimization now that there is arbitrary host code being executed, but I could imagine adding a parameter to `jax.experimental.host_callback.call` indicating how much device memory will be required. I guess that since `host_callback` is still experimental, it is not too late to think about whether we want this to be a pure "host" callback, more of a "return control to host" callback, or even something like `jax.unjitted_call` that may also make use of the device. E.g.:
```python
@jax.jit
def fun(input):
    aux1 = input ** 2 + 5  # This will run jitted.

    def unjitted_fn(x):
        # This gives control back to Python, but keeps the input array 'x' on device.
        print("host")  # This will print every single time.
        # Users may still feed x to a CPU-only library (e.g. scipy); that triggers an
        # automatic transfer of the array to host memory (as usual), but only if actually necessary.
        return x ** 2 + 1  # This will run un-jitted, but on device.

    aux2 = jax.unjitted_call(unjitted_fn, aux1,
                             result_shape=jax.ShapeDtypeStruct(aux1.shape, aux1.dtype),
                             max_memory=...)
    output = aux2 ** 3 + 5  # This will run jitted.
    return output
```
Quick question regarding host_callback.call within jitted functions and multiprocessing... Say I have a machine with a GPU and 4 CPU cores. If I jit and vmap a function that calls a Python function through host_callback.call with a batch dimension of size 4, is this a way to have the Python functions run concurrently on all four CPU cores at once, or will JAX force them to run on the same CPU core one after another? Wondering if this is sort of a hack to get some concurrency with Python functions through JAX.
Inside a vmap computation you will see a single call to the host, with the entire batch. You'd have to write your host function to split the data into batches and run it in several threads.
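A minimal sketch of that pattern (the names are mine): the host function receives the full batch in one call and fans it out over a thread pool itself. Note that pure-Python per-example work is still serialized by the GIL; threads only buy concurrency when the work releases it (e.g. NumPy/SciPy kernels), otherwise a process pool would be needed.

```python
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def per_example_fn(row):
    # Stands in for an arbitrary host-side, per-example computation.
    return row * 2.0

def host_fun(batch):
    # Under vmap, this receives the entire (4, ...) batch in a single call.
    with ThreadPoolExecutor(max_workers=4) as pool:
        return np.stack(list(pool.map(per_example_fn, batch)))

def device_fun(x):
    return hcb.call(host_fun, x,
                    result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))

out = jax.jit(jax.vmap(device_fun))(jnp.ones((4, 8)))
```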
@sharadmv please close this issue when the replacement is ready!