
Added `eqx.experimental.noinline`

patrick-kidger opened this issue 2 years ago • 8 comments

TL;DR: XLA sub-graphs!

Background

At present, JAX inlines the entire computation into a single XLA graph.

However, many scientific computing applications involve defining some modestly complicated function and then calling it numerous times in different contexts. For example, the vector field of an ODE must be traced 22 times when using the Kvaerno5 solver with automatic initial step size selection. (In ways that cannot easily be tidied up into a lax.scan or similar.)
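For a concrete illustration, here's a minimal sketch using diffrax (tolerances chosen arbitrarily); the print statement fires once per trace of the vector field, not once per runtime evaluation:

import diffrax as dfx

def vector_field(t, y, args):
    print("traced")  # fires on each trace, not on each evaluation
    return -y

# Kvaerno5 is implicit: its stages, the error estimate, and the automatic
# initial step size selection each trace the vector field in a different
# context, so "traced" is printed many times.
dfx.diffeqsolve(
    dfx.ODETerm(vector_field),
    dfx.Kvaerno5(),
    t0=0.0, t1=1.0, dt0=None, y0=1.0,
    stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5),
)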

Inlining without awareness of this repeated structure means the compiler does far more work than it needs to. I know of current real-world examples with compile times of about an hour.

To support this use case there's been talk for quite a while about adding support to JAX or XLA for sub-programs, e.g. https://github.com/google/jax/issues/10974 https://github.com/google/jax/issues/4572 https://github.com/google/jax/issues/3847 https://github.com/google/jax/issues/9298

no-inline decorator

Introducing equinox.experimental.noinline. This decorator places the function in a separate XLA computation graph, and links it up with the main one via jax.experimental.host_callback.call. The decorated function is traced only once, exists as only a single jaxpr, and is compiled only once. It can still be transformed via grad, vmap, etc.
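Usage is a straightforward decorator; here's a minimal sketch (function names hypothetical):

import equinox as eqx
import jax
import jax.numpy as jnp

@eqx.experimental.noinline
def vf(t, y):
    # Traced once, existing as one jaxpr, compiled once.
    return -jnp.sin(y)

@jax.jit
def two_steps(y0):
    # Both call sites link against the same separately-compiled sub-graph,
    # rather than inlining two copies of vf into this XLA program.
    y1 = y0 + 0.1 * vf(0.0, y0)
    return y1 + 0.1 * vf(0.1, y1)

two_steps(jnp.array(1.0))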

Running the included benchmark benchmarks/noinline.py, we obtain a compile-time reduction of 36 seconds -> 25 seconds, at the expense of a large runtime increase: 0.002 seconds -> 0.6 seconds. In practice that's still a good net saving (36 seconds -> 25.6 seconds) in the common use-case that you're developing and debugging your program.

Going further and switching the solver in the benchmark from dfx.Kvaerno5() to dfx.Dopri8() gives even better results: a compile time reduction of 23 seconds -> 8 seconds (!), with a runtime increase of 0.002 seconds -> 0.16 seconds. (I chose not to implement this as the default benchmark, just because I have other unrelated plans for improving the compile time of Dopri8.)

Limitations

  1. All the awful hackery and monkey-patching of JAX internals needed to make this work.
  2. This will only be efficient on the CPU. On the GPU it'll entail copies to and from the device. However, I speculate that this isn't actually necessary, and may just be a limitation of our current use of host_callback?
  3. The runtime performance has cratered. I speculate a lot of that cost is due to the back-and-forth switching via Python, again due to our use of host_callback. (Flame graphs TBD.) Possibly also something GIL-related?

Nonetheless, our main use-case is on the CPU, and the overall compile-time improvements on the benchmark -- roughly 1.5x to 3x -- are enough to make me happy. This is something we're looking forward to relying on, as those 1+ hour compile times are really biting us.

CC

@shoyer and @YouJiacheng as I know you've both wanted this functionality in the past. @federicov (+team) for using this.

Also speculatively tagging @gnecula (I have it in my head that you're behind host_callback?) and @mattjj for possible interest. (Feel free to ignore this.)

patrick-kidger avatar Jul 01 '22 00:07 patrick-kidger

I wonder why we cannot implement this functionality by registering the user function as a primitive, with the help of mlir.lower_fun and mlir.cache_lowering. Though it seems that in the MHLO=>HLO pass, a mlir.cache_lowering-wrapped function will still be expanded: https://github.com/google/jax/blob/118db407f24f20a77ef491d504290f8c29d57d05/jax/_src/lax/lax.py#L1952-L1955

And it might be possible to generate a binary and then wrap it in a CustomCall, which will never be inlined, without the cost of host_callback.

YouJiacheng avatar Jul 03 '22 14:07 YouJiacheng

Doing so with MLIR: right! After implementing the above I started wondering something similar. As an end goal, could we arrange to generate a single XlaComputation that all usages point at, via XLA's call operation? If need be, wrap it in a fake while loop (the same trick jax.checkpoint uses) to prevent any inlining analysis that might slow things down -- but I assume just having a single computation object, rather than numerous copies, would already be enough to improve things sufficiently.

The implications of this would be quite dramatic: it would enable a "bounded while loop" construct -- i.e. the core backbone of diffeq solvers, nonlinear optimisers, etc. This construct requires recursively tracing nested operations of the form lax.cond(pred, fn, jax.checkpoint(fn), *operands). The nesting means that at present we end up with an exponential explosion -- each use of the above traces two copies of fn, each of which traces two more, etc. -- which makes this infeasible for compile-time performance. As such, all current diffeq solvers / nonlinear optimisers / etc. hack around this and suffer substantial runtime performance penalties as a result. Being able to apply a noinline decorator to fn would avoid this issue, dropping the compile-time complexity from O(exp n) to just O(n). This would be amazing!
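To make that blow-up concrete, here's a minimal self-contained sketch (hypothetical names; just the shape of the construct, not a real bounded while loop) that counts how many times fn is traced:

import jax
import jax.lax as lax

n_traces = 0

def fn(x):
    global n_traces
    n_traces += 1  # counts traces, not runtime evaluations
    return x - 0.1 * x  # stand-in for e.g. one solver step

def nested(pred, x, depth):
    # The form quoted above: lax.cond(pred, fn, jax.checkpoint(fn), *operands).
    # Each level traces its continuation twice, so fn is traced 2**depth times.
    if depth == 0:
        return fn(x)
    inner = lambda y: nested(pred, y, depth - 1)
    return lax.cond(pred, inner, jax.checkpoint(inner), x)

jax.make_jaxpr(lambda x: nested(x > 0, x, 4))(1.0)
print(n_traces)  # 16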

That said, a noinline implementation like the above sounds probably simple enough (if you know what you're doing), and as per the previous paragraph this feature request is clearly important enough, that I would have thought something like the above would exist in core JAX already by now. We're really bumping up against the edge of my knowledge wrt lowerings here, and perhaps this is impossible for some reason / entirely separate XLA graphs really are necessary.

patrick-kidger avatar Jul 03 '22 14:07 patrick-kidger

If wrapping in an rng fake while loop can successfully prevent inlining (really?), then implementing noinline should be easier. But I guess this trick can only prevent the computation from being moved outside; XLA can still clone the code at each call site.

YouJiacheng avatar Jul 03 '22 15:07 YouJiacheng

My trial:

from functools import partial, lru_cache

import jax
from jax import core
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun

def foo(x):
    print('traced')
    return x + 1

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)

foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: abstract_eval_fun(foo, x))
# Cache the lowering so that every call site reuses a single MHLO function.
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(foo, multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x

print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir()) # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string()) # post-compilation HLO

It is only inlined at the post-compilation HLO level:

traced
{ lambda ; a:i32[]. let
    b:i32[] = foo a
    c:i32[] = foo b
    d:i32[] = foo c
    e:i32[] = foo d
  in (e,) }
traced
module @jit_bar.0 {
  func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @foo(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @foo(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @foo(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @foo(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func private @foo(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<1> : tensor<i32>
    %1 = mhlo.add %arg0, %0 : tensor<i32>
    return %1 : tensor<i32>
  }
}

traced
HloModule jit_bar.1

ENTRY %main.22 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0)
  %constant_7 = s32[] constant(4)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/add" source_file="*" source_line=28}
}

YouJiacheng avatar Jul 03 '22 16:07 YouJiacheng

from functools import partial, lru_cache

import jax
from jax import core, lax
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
import jax.linear_util as lu
from jax._src.api_util import flatten_fun, shaped_abstractify

import numpy as np

def foo(x):
    print('traced')
    return x + 1

def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *(shaped_abstractify(a) for a in args_flat))
    return tree_unflatten(out_tree(), out)

cached_abstract_eval_fun = lru_cache(abstract_eval_fun)

def _dummy_result(aval: core.AbstractValue):
  if aval is core.abstract_token:
    return lax.create_token()
  else:
    return lax.full(aval.shape, 0, aval.dtype)

def wrapped(fun, *args, **kwargs):
    # Run `fun` as the body of a trip-count-1 while loop whose condition
    # calls lax.rng_uniform, hoping XLA cannot see through the loop and
    # therefore will not optimise the body into its surroundings.
    avals_out = abstract_eval_fun(fun, *args, **kwargs)
    dummies_like_result = tree_map(_dummy_result, avals_out)
    carry_init = (np.int32(0), dummies_like_result, args, kwargs)
    def cond(carry):
        counter, _, _, _ = carry
        return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

    def body(carry):
        counter, _, args, kwargs = carry
        results = fun(*args, **kwargs)
        return (counter + 1, results, args, kwargs)

    carry_res = lax.while_loop(cond, body, carry_init)
    return carry_res[1]

foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: cached_abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(lambda x: wrapped(foo, x), multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x

print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())

XLA STILL inlines all of foo -- copying the while loop 4 times!!!

HloModule jit_bar.1

%region_0.2 (arg_tuple.3: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.3 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.10 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.3), index=0
  %constant_7 = s32[] constant(1)
  %add.9 = s32[] add(s32[] %get-tuple-element.10, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.18 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.3), index=2
  %add.8 = s32[] add(s32[] %get-tuple-element.18, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.6 = (s32[], s32[], s32[]) tuple(s32[] %add.9, s32[] %add.8, s32[] %get-tuple-element.18)
}

%region_1.11 (arg_tuple.12: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.12 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.13 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.12), index=0
  %constant_776 = s32[] constant(1)
  ROOT %compare.20 = pred[] compare(s32[] %get-tuple-element.13, s32[] %constant_776), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=1
}

%region_0.30 (arg_tuple.31: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.31 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.22 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.31), index=0
  %constant_35 = s32[] constant(1)
  %add.37 = s32[] add(s32[] %get-tuple-element.22, s32[] %constant_35), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.30 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.31), index=2
  %add.36 = s32[] add(s32[] %get-tuple-element.30, s32[] %constant_35), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.9 = (s32[], s32[], s32[]) tuple(s32[] %add.37, s32[] %add.36, s32[] %get-tuple-element.30)
}

%region_1.39 (arg_tuple.40: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.40 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.41 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.40), index=0
  %constant_782 = s32[] constant(1)
  ROOT %compare.48 = pred[] compare(s32[] %get-tuple-element.41, s32[] %constant_782), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state.1 = u64[2]{0} rng-get-and-update-state(), delta=1
}

%region_0.58 (arg_tuple.59: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.59 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.37 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.59), index=0
  %constant_63 = s32[] constant(1)
  %add.65 = s32[] add(s32[] %get-tuple-element.37, s32[] %constant_63), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.45 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.59), index=2
  %add.64 = s32[] add(s32[] %get-tuple-element.45, s32[] %constant_63), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.13 = (s32[], s32[], s32[]) tuple(s32[] %add.65, s32[] %add.64, s32[] %get-tuple-element.45)
}

%region_1.67 (arg_tuple.68: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.68 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.69 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.68), index=0
  %constant_788 = s32[] constant(1)
  ROOT %compare.76 = pred[] compare(s32[] %get-tuple-element.69, s32[] %constant_788), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state.2 = u64[2]{0} rng-get-and-update-state(), delta=1
}

%region_0.86 (arg_tuple.87: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.87 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.49 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.87), index=0
  %constant_91 = s32[] constant(1)
  %add.93 = s32[] add(s32[] %get-tuple-element.49, s32[] %constant_91), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.57 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.87), index=2
  %add.92 = s32[] add(s32[] %get-tuple-element.57, s32[] %constant_91), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.16 = (s32[], s32[], s32[]) tuple(s32[] %add.93, s32[] %add.92, s32[] %get-tuple-element.57)
}

%region_1.95 (arg_tuple.96: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.96 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.97 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.96), index=0
  %constant_794 = s32[] constant(1)
  ROOT %compare.104 = pred[] compare(s32[] %get-tuple-element.97, s32[] %constant_794), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state.3 = u64[2]{0} rng-get-and-update-state(), delta=1
}

ENTRY %main.114 (Arg_0.1: s32[]) -> s32[] {
  %constant_379 = s32[] constant(0)
  %copy.24 = s32[] copy(s32[] %constant_379)
  %copy.19 = s32[] copy(s32[] %copy.24)
  %copy.12 = s32[] copy(s32[] %copy.24)
  %copy.13 = s32[] copy(s32[] %copy.24)
  %copy.6 = s32[] copy(s32[] %copy.24)
  %copy.7 = s32[] copy(s32[] %copy.24)
  %copy = s32[] copy(s32[] %copy.24)
  %copy.1 = s32[] copy(s32[] %copy.24)
  %Arg_0.1 = s32[] parameter(0)
  %tuple.4 = (s32[], s32[], s32[]) tuple(s32[] %copy, s32[] %copy.1, s32[] %Arg_0.1)
  %while.0 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.4), condition=%region_1.11, body=%region_0.2, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  %get-tuple-element.21 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.0), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %tuple.7 = (s32[], s32[], s32[]) tuple(s32[] %copy.6, s32[] %copy.7, s32[] %get-tuple-element.21)
  %while.1 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.7), condition=%region_1.39, body=%region_0.30, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  %get-tuple-element.36 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.1), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %tuple.11 = (s32[], s32[], s32[]) tuple(s32[] %copy.12, s32[] %copy.13, s32[] %get-tuple-element.36)
  %while.2 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.11), condition=%region_1.67, body=%region_0.58, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  %get-tuple-element.48 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.2), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %tuple.14 = (s32[], s32[], s32[]) tuple(s32[] %copy.24, s32[] %copy.19, s32[] %get-tuple-element.48)
  %while.3 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.14), condition=%region_1.95, body=%region_0.86, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.3), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
}

YouJiacheng avatar Jul 03 '22 17:07 YouJiacheng

If wrapping in an rng fake while loop can successfully prevent inlining (really?), then implementing noinline should be easier. But I guess this trick can only prevent the computation from being moved outside; XLA can still clone the code at each call site.

Right -- I think the fake while loop would prevent inlining at the HLO -> compile-to-whatever-backend level, i.e. prevent optimisation of the function with respect to its calling context. Indeed, I didn't mean that it would prevent the cloning issue.

XLA STILL inlines all of foo -- copying the while loop 4 times!!!

Bother. Looking at mlir.cache_lowering, it looks like this does exactly what I was hoping might work -- producing only a single function object and then re-using it whenever possible. It's really unfortunate that JAX/XLA seems to be limited like this.

Perhaps the way forward really is to implement a "better" version of host_callback that fits this use-case. The new work on effectful jaxprs might make this a lot easier, since I assume it should remove a lot of the rewriting/hackery/etc. that host_callback seems to bring in. (Effectful jaxprs have so many use cases -- I'm very excited to be getting algebraic effects in JAX.)

patrick-kidger avatar Jul 03 '22 17:07 patrick-kidger

Hi! I prototyped a (possibly) simpler version than host_callback, directly using mlir.emit_python_callback. But xla_python_gpu_callback uses numpy to transfer data. Thanks to jaxlib 0.3.11 adding xla_python_gpu_callback, we don't need to use the complex outfeed mechanism, at the cost of some performance. (I found it is 20% faster than host_callback.call for a (1024, 1024) float32 identity function, but spends nearly 3x the time for (8, 1024, 1024) or larger float32 arrays.)

from functools import lru_cache

import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun

def foo(x):
    print('traced')
    print(type(x))
    return x + 1

def foo_lowered(x):
    print(type(x))
    return x + 1

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)

callback_p = core.Primitive('callback')
callback_p.def_impl(lambda *args, callback, **_: callback(*args))
callback_p.def_abstract_eval(lambda *args, callback, **_: abstract_eval_fun(callback, *args))

def callback_lowering(ctx, *args, callback, callback_lowered):
    # emit_python_callback expects the callback to return a flat sequence of
    # results, so wrap single-output callbacks in a 1-tuple.
    try:
        iter(abstract_eval_fun(callback, *ctx.avals_in))
    except TypeError:
        f = lambda *args: (callback_lowered(*args),)
    else:
        f = callback_lowered
    result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, f, args, ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result

mlir.register_lowering(callback_p, mlir.cache_lowering(callback_lowering))


def bar(x):
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x

print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir()) # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string()) # post-compilation HLO
traced
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
{ lambda ; a:i32[]. let
    b:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] a
    c:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] b
    d:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] c
    e:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] d
  in (e,) }
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
5
module @jit_bar.1 {
  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @callback(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @callback(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @callback(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @callback(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func.func private @callback(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<94158890291952> : tensor<i64>
    %1 = "mhlo.custom_call"(%0, %arg0) {api_version = 2 : i32, backend_config = "94158890291952", call_target_name = "xla_python_gpu_callback", called_computations = [], has_side_effect = false} : (tensor<i64>, tensor<i32>) -> tuple<tensor<i32>>
    %2 = "mhlo.get_tuple_element"(%1) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
    return %2 : tensor<i32>
  }
}

HloModule jit_bar.2, entry_computation_layout={(s32[])->s32[]}

ENTRY %main.26 (Arg_0.1: s32[]) -> s32[] {
  %constant_0 = s64[] constant(94158882601360)
  %Arg_0.1 = s32[] parameter(0)
  %custom-call.0 = (s32[]) custom-call(s64[] %constant_0, s32[] %Arg_0.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.0 = s32[] get-tuple-element((s32[]) %custom-call.0), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.1 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.0), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.1 = s32[] get-tuple-element((s32[]) %custom-call.1), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.2 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.2 = s32[] get-tuple-element((s32[]) %custom-call.2), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.3 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.2), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[]) %custom-call.3), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
}

YouJiacheng avatar Jul 06 '22 15:07 YouJiacheng

BTW, host_callback already uses the CustomCall implementation by default on CPU (no outfeed).

There is a new mechanism being developed to replace host_callback, but it is not yet ready.

In any case, using host_callback here seems more like a workaround than the best solution.

gnecula avatar Jul 06 '22 20:07 gnecula

Closing in favour of #218, which adds equinox.internal.noinline. This is a heavily improved version of the functionality here:

  • Much faster.
  • Uses a custom primitive directly, instead of monkey-patching host_callback.call.
  • Offers the ability to recompile only the part of a computation graph that has changed, without needing to recompile the entire graph. For example:
    import equinox as eqx
    import jax.numpy as jnp
    from equinox.internal import noinline

    def abstract(x, y):
        # Shared abstract signature for f and g: because both produce the
        # same output shapes/dtypes, one can be swapped for the other
        # without recompiling the caller.
        return jnp.broadcast_arrays(x, y)[0]
    
    def f(x, y):
        print("Compiling f!")
        return x + y
    
    def g(x, y):
        print("Compiling g!")
        return x * y
    
    f = noinline(f, abstract)
    g = noinline(g, abstract)
    
    def call(fn, x, y):
        print("Compiling call!")
        return fn(x, y)
    
    call = eqx.filter_jit(call)
    call(f, 1, 1)  # Compiling call! Compiling f!
    call(g, 1, 1)  # Compiling g!  [But does not recompile call!]
    

patrick-kidger avatar Nov 02 '22 23:11 patrick-kidger