
Using a device function inside the host function of host_callback fails confusingly

Open tomhennigan opened this issue 3 years ago • 6 comments

Using a JAX device function inside the host function of host_callback causes hangs, long compiles, or OOMs (the exact failure mode varies with the backend).

This not being supported is probably working as intended (WAI), but we could provide a better error message to avoid user confusion. Minimal reproducer below:

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def host_fun(m: np.ndarray) -> np.ndarray:
  return np.sum(m)  # works
  # return jnp.sum(m)  # causes errors

def device_fun(m):
  return hcb.call(host_fun, m,
                  result_shape=jax.ShapeDtypeStruct(m.shape, m.dtype))

jax.jit(device_fun)(0)

tomhennigan avatar Mar 04 '21 17:03 tomhennigan

The success of the reentrant call depends on the device configuration.

  • If there is only a single CPU device available, the call causes a deadlock.
  • If two or more CPU devices are explicitly made visible, one can create a working configuration by jitting device_fun onto one device and forcing host_fun onto a different device via the device kwarg of jit(), as sketched below.
  • Similarly, a working configuration can be achieved by spreading host_fun and device_fun across a CPU and a GPU device.

It is less clear what happens under transformations like pmap.
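
A minimal, untested sketch of the two-CPU-device workaround from the second bullet above. It assumes the --xla_force_host_platform_device_count XLA flag to expose two CPU devices; the flag must be set before jax is imported:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"  # must run before importing jax

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

cpu0, cpu1 = jax.devices("cpu")[:2]

def host_fun(m):
  # Dispatch the device computation to a CPU device that is *not*
  # blocked by the pending callback.
  return jax.jit(jnp.sum, device=cpu1)(m)

def device_fun(m):
  return hcb.call(host_fun, m,
                  result_shape=jax.ShapeDtypeStruct((), m.dtype))

jax.jit(device_fun, device=cpu0)(jnp.arange(4.))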

ahoenselaar avatar Mar 15 '21 16:03 ahoenselaar

Allowing re-entrant calls is a larger project, but I will look into providing a better message.

The rule is that when a callback executes on a device it blocks the device until the callback finishes. The callback cannot launch other computations on the same device.

gnecula avatar Mar 16 '21 06:03 gnecula

This is related to the potential reentrant call above, but looks at it from a more general perspective.

One of the potential use cases for jax.experimental.host_callback.call is to get around jax.jit's fixed-shape requirement for some specific intermediate computations, since anything that happens in the host callback can take an arbitrary shape. This can be very valuable for prototyping and research code. Of course, using jax.experimental.host_callback.call for this purpose means:

  • Data to and from jax.experimental.host_callback.call needs to be copied from the device to the host and vice versa, which adds overhead.
  • There cannot be any device parallelization in the callback.
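
For reference, a minimal sketch of this dynamic-shape use case (np.unique here stands in for any intermediate of data-dependent shape):

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def host_fun(m: np.ndarray) -> np.ndarray:
  uniques = np.unique(m)  # data-dependent shape, not expressible under jit
  return np.sum(uniques) * np.ones_like(m)  # back to the fixed result_shape

def device_fun(m):
  return hcb.call(host_fun, m,
                  result_shape=jax.ShapeDtypeStruct(m.shape, m.dtype))

jax.jit(device_fun)(jnp.array([1., 1., 2., 3.]))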

It would be fantastic if the inputs and outputs to the host_fun remained DeviceArray instead of numpy arrays. I understand this may cause complications for XLA optimization now that there is arbitrary host code being executed, but I could imagine adding a parameter to jax.experimental.host_callback.call indicating how much device memory will be required. I guess since host_callback is still experimental, it is not too late to think about whether we want this to be a pure "host" callback, or more of a "return control to host" callback, or even something like jax.unjitted_call that may also make use of the device. E.g.:

@jax.jit
def fun(input):
  aux1 = input ** 2 + 5  # This will run jitted.
  
  def unjitted_fn(x):
    # This will give back control to Python, but keep the input array 'x' on device.
    print("host")  # This will print every single time.
    # Users may still feed x to a CPU-only library (e.g. scipy), which will trigger an
    # automatic transfer of the array to host memory (as usual), but only if actually necessary.
    return x ** 2 + 1  # This will run un-jitted, but on device.
  aux2 = jax.unjitted_call(unjitted_fn, aux1,
                           result_shape=jax.ShapeDtypeStruct(aux1.shape, aux1.dtype),
                           max_memory=...)

  output = aux2 ** 3 + 5  # This will run jitted.
  return output

alvarosg avatar Mar 17 '21 15:03 alvarosg

Quick question regarding host_callback.call within jitted functions and multiprocessing... Say I have a machine with a GPU and 4 CPU cores. If I jit and vmap a function that calls a Python function through host_callback.call with batch_dim=4, is this a way to have the Python functions run concurrently on all four CPU cores at once, or will JAX force them to run on the same CPU core one after another? Wondering if this is sort of a hack to get some concurrency with Python functions through JAX.

evanatyourservice avatar Apr 26 '22 15:04 evanatyourservice

Inside a vmap computation you will see a single call to the host, with the entire batch. You'd have to write your host function to split the data into batches and run it in several threads.
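
A minimal, untested sketch of that fan-out; slow_scalar_fn is a hypothetical stand-in for the expensive per-example Python function, and real concurrency here still depends on it releasing the GIL (e.g. in NumPy or I/O calls):

from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np

def slow_scalar_fn(x):
  return x ** 2  # stand-in for an expensive per-example Python computation

def host_fun(batch: np.ndarray) -> np.ndarray:
  # The host sees the entire batch in a single call; fan the examples
  # out to a thread pool.
  with ThreadPoolExecutor(max_workers=4) as pool:
    return np.stack(list(pool.map(slow_scalar_fn, batch)))

def device_fun(batch):
  return hcb.call(host_fun, batch,
                  result_shape=jax.ShapeDtypeStruct(batch.shape, batch.dtype))

jax.jit(device_fun)(jnp.arange(4.))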


gnecula avatar Apr 27 '22 11:04 gnecula

@sharadmv please close this issue when the replacement is ready!

mattjj avatar Sep 07 '22 19:09 mattjj