Added `eqx.experimental.noinline`
TL;DR: XLA sub-graphs!
Background
At present, JAX inlines the entire computation into a single XLA graph.
However, many scientific computing applications involve defining some modestly complicated function and then calling this function numerous times in different contexts. For example, the vector field for an ODE must be traced 22 times when using the Kvaerno5 solver with automatic initial step size selection. (In ways that cannot easily be tidied up into a `lax.scan` or similar.)
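To make the repeated tracing concrete, here is a small sketch (not from this PR; the solver options shown are illustrative rather than the exact benchmark configuration) that makes visible how often a vector field is traced during a single adaptive solve with diffrax:

```python
import diffrax as dfx
import jax.numpy as jnp

def vector_field(t, y, args):
    print("tracing vector field")  # printed once each time the function is traced
    return -y

# Kvaerno5 with automatic initial step size selection (dt0=None plus a PID
# controller); compiling this solve re-traces `vector_field` many times.
dfx.diffeqsolve(
    dfx.ODETerm(vector_field),
    dfx.Kvaerno5(),
    t0=0.0,
    t1=1.0,
    dt0=None,
    y0=jnp.array(1.0),
    stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5),
)
```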
Inlining without awareness of the repeated structure means the compiler is less efficient than it could be. I know of current examples with compile times about an hour long.
To support this use case there's been talk for quite a while about adding support to JAX or XLA for sub-programs, e.g. https://github.com/google/jax/issues/10974 https://github.com/google/jax/issues/4572 https://github.com/google/jax/issues/3847 https://github.com/google/jax/issues/9298
no-inline decorator
Introducing `equinox.experimental.noinline`. This decorator places the function in a separate XLA computation graph, and links it up with the main one via `jax.experimental.host_callback.call`. The decorated function is only traced once; only one copy of it exists as a jaxpr; and it is only compiled once. It can still be transformed via `grad`, `vmap`, etc.
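As a usage sketch (a hypothetical example; the exact call signature may differ from the final API), the intent is that you simply decorate the repeatedly-called function:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

@eqx.experimental.noinline
def vector_field(t, y):
    return -jnp.sin(y) + t

@jax.jit
def two_steps(y0):
    # `vector_field` appears at several call sites, but is traced/compiled once.
    k1 = vector_field(jnp.array(0.0), y0)
    k2 = vector_field(jnp.array(0.5), y0 + 0.5 * k1)
    return y0 + k2

two_steps(jnp.array(1.0))
grads = jax.grad(two_steps)(jnp.array(1.0))  # still differentiable
```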
Running the included benchmark `benchmarks/noinline.py`, we obtain a reduction in compile time of 36 seconds -> 25 seconds, at the expense of a large runtime increase, 0.002 seconds -> 0.6 seconds. In practice that's still a good net saving (36 seconds -> 25.6 seconds) in the common use-case that you're developing and debugging your program.
Going further and switching the solver in the benchmark from `dfx.Kvaerno5()` to `dfx.Dopri8()` gives even better results: a compile time reduction of 23 seconds -> 8 seconds (!), with a runtime increase of 0.002 seconds -> 0.16 seconds. (I chose not to implement this as the default benchmark, just because I have other unrelated plans for improving the compile time of `Dopri8`.)
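For reference, the compile-time and runtime numbers above can be measured separately with JAX's AOT API (a rough sketch, not the actual benchmark script; it assumes `fn` returns a single array):

```python
import time
import jax

def measure(fn, *args):
    t0 = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()  # tracing + compilation
    t1 = time.perf_counter()
    compiled(*args).block_until_ready()            # execution only
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1  # (compile time, runtime)
```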
Limitations
- All the awful hackery and monkey-patching of JAX internals needed to make this work.
- This will only be efficient on the CPU. On the GPU it'll entail copies to and from the device. However, I speculate that this isn't actually necessary, and may just be a limitation of our current use of `host_callback`?
- The runtime performance has cratered. I speculate a lot of that cost is due to the back-and-forth switching via Python, again due to our use of `host_callback`. (Flame graphs TBD.) Possibly also something GIL-related?
Nonetheless, our main use-case is on the CPU, and on the benchmark above this translates to compile speed improvements of 1.5x to 3x, which is enough to make me happy. This is something we're looking forward to relying on, as those 1+ hour compile times are really biting us.
CC @shoyer and @YouJiacheng, as I know you've both wanted this functionality in the past. @federicov (+team) for using this.
Also speculatively tagging @gnecula (I have it in my head that you're behind `host_callback`?) and @mattjj for possible interest. (Feel free to ignore this.)
I wonder why we cannot implement this functionality by registering the user function as a primitive, with the help of `mlir.lower_fun` and `mlir.cache_lowering`.
Though it seems that in the MHLO=>HLO pass, a `mlir.cache_lowering`-wrapped function will still be expanded:
https://github.com/google/jax/blob/118db407f24f20a77ef491d504290f8c29d57d05/jax/_src/lax/lax.py#L1952-L1955
And it might be possible to generate a binary and then wrap it with a `CustomCall`, which will never be inlined, without the cost of `host_callback`.
Doing so with MLIR: right! After implementing the above I started wondering something similar. As an end goal, could we arrange to generate a single `XlaComputation` that all usages point at, via XLA's `call` operation? If need be, wrap it in a fake while loop (the same trick `jax.checkpoint` uses) to prevent any inlining analysis that might slow things down, but I assume just having a single computation object, rather than numerous copies, would already be enough to improve things sufficiently.
The implications of this would be quite dramatic: it would enable a "bounded while loop" construct -- i.e. the core backbone of diffeq solvers, nonlinear optimisers, etc. This construct requires recursively tracing nested operations of the form `lax.cond(pred, fn, jax.checkpoint(fn), *operands)`. The fact that they're nested means that at present we end up with an exponential explosion -- each use of the above involves tracing two copies of `fn`, each of which traces two more copies, etc. -- which makes this infeasible for compile-time performance. (See the toy sketch below.) As such, all current diffeq solvers / nonlinear optimisers / etc. hack around this and suffer substantial runtime performance penalties as a result. Being able to apply a `noinline` decorator to `fn` would avoid this issue, and thus drop the compile-time complexity from O(exp n) to just O(n). This would be amazing!
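The toy sketch below (hypothetical names, not diffrax/Equinox code) makes the blow-up visible: each level of nesting traces its inner function twice, once directly and once under `jax.checkpoint`, so the innermost function is traced 2^depth times:

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(x):
    print("tracing step")  # printed once per trace of `step`
    return x + 1.0

def nested(fn, depth):
    # Build the recursive structure lax.cond(pred, fn, jax.checkpoint(fn), x).
    if depth == 0:
        return fn
    inner = nested(fn, depth - 1)
    def wrapper(x):
        pred = x < 100.0
        # Both branches are traced, so `inner` is traced twice at every level.
        return lax.cond(pred, inner, jax.checkpoint(inner), x)
    return wrapper

jax.make_jaxpr(nested(step, 3))(jnp.array(0.0))  # "tracing step" prints 2**3 = 8 times
```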
That said, a `noinline` implementation like the above sounds simple enough (if you know what you're doing), and as per the previous paragraph this feature request is clearly important enough, that I would have thought something like the above would exist in core JAX already by now. We're really bumping up against the edge of my knowledge wrt lowerings here, and perhaps this is impossible for some reason / entirely separate XLA graphs really are necessary.
If wrapping with a fake RNG-based while loop can successfully prevent inlining (really?), implementing `noinline` should be easier.
But I guess this trick can only prevent computation from being moved outside the loop. XLA can still clone the code at each call site.
My trial:
```python
from functools import partial, lru_cache

import jax
from jax import core
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun


def foo(x):
    print('traced')
    return x + 1


@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    # Abstractly evaluate `fun` to get its output avals, caching the result.
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)


# Register `foo` as a primitive whose lowering is cached, so that the MHLO
# module contains a single `@foo` function referenced by every call site.
foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(foo, multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x


print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())  # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())  # post-compilation HLO
```
Only inlined at the post-compilation HLO level:
```
traced
{ lambda ; a:i32[]. let
    b:i32[] = foo a
    c:i32[] = foo b
    d:i32[] = foo c
    e:i32[] = foo d
  in (e,) }
traced
module @jit_bar.0 {
  func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @foo(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @foo(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @foo(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @foo(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func private @foo(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<1> : tensor<i32>
    %1 = mhlo.add %arg0, %0 : tensor<i32>
    return %1 : tensor<i32>
  }
}
traced
HloModule jit_bar.1
ENTRY %main.22 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0)
  %constant_7 = s32[] constant(4)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/add" source_file="*" source_line=28}
}
```
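The same experiment, but with the lowered function wrapped in a fake while loop whose trip count depends on a runtime `lax.rng_uniform` call, in the hope that XLA cannot see through it: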
```python
from functools import partial, lru_cache

import jax
from jax import core, lax
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
import jax.linear_util as lu
from jax._src.api_util import flatten_fun, shaped_abstractify
import numpy as np


def foo(x):
    print('traced')
    return x + 1


def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *(shaped_abstractify(a) for a in args_flat))
    return tree_unflatten(out_tree(), out)


cached_abstract_eval_fun = lru_cache(abstract_eval_fun)


def _dummy_result(aval: core.AbstractValue):
    if aval is core.abstract_token:
        return lax.create_token()
    else:
        return lax.full(aval.shape, 0, aval.dtype)


def wrapped(fun, *args, **kwargs):
    # Wrap `fun` in a single-iteration while loop whose trip count depends on
    # `lax.rng_uniform`, so that XLA cannot statically elide the loop.
    avals_out = abstract_eval_fun(fun, *args, **kwargs)
    dummies_like_result = tree_map(_dummy_result, avals_out)
    carry_init = (np.int32(0), dummies_like_result, args, kwargs)

    def cond(carry):
        counter, _, _, _ = carry
        return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

    def body(carry):
        counter, _, args, kwargs = carry
        results = fun(*args, **kwargs)
        return (counter + 1, results, args, kwargs)

    carry_res = lax.while_loop(cond, body, carry_init)
    return carry_res[1]


foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: cached_abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(lambda x: wrapped(foo, x), multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x


print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())
```
XLA STILL inlines all `foo` -- copying the while loop 4 times!!!
```
HloModule jit_bar.1
%region_0.2 (arg_tuple.3: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
%arg_tuple.3 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.10 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.3), index=0
%constant_7 = s32[] constant(1)
%add.9 = s32[] add(s32[] %get-tuple-element.10, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
%get-tuple-element.18 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.3), index=2
%add.8 = s32[] add(s32[] %get-tuple-element.18, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
ROOT %tuple.6 = (s32[], s32[], s32[]) tuple(s32[] %add.9, s32[] %add.8, s32[] %get-tuple-element.18)
}
%region_1.11 (arg_tuple.12: (s32[], s32[], s32[])) -> pred[] {
%arg_tuple.12 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.13 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.12), index=0
%constant_776 = s32[] constant(1)
ROOT %compare.20 = pred[] compare(s32[] %get-tuple-element.13, s32[] %constant_776), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
%rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=1
}
%region_0.30 (arg_tuple.31: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
%arg_tuple.31 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.22 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.31), index=0
%constant_35 = s32[] constant(1)
%add.37 = s32[] add(s32[] %get-tuple-element.22, s32[] %constant_35), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
%get-tuple-element.30 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.31), index=2
%add.36 = s32[] add(s32[] %get-tuple-element.30, s32[] %constant_35), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
ROOT %tuple.9 = (s32[], s32[], s32[]) tuple(s32[] %add.37, s32[] %add.36, s32[] %get-tuple-element.30)
}
%region_1.39 (arg_tuple.40: (s32[], s32[], s32[])) -> pred[] {
%arg_tuple.40 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.41 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.40), index=0
%constant_782 = s32[] constant(1)
ROOT %compare.48 = pred[] compare(s32[] %get-tuple-element.41, s32[] %constant_782), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
%rng-get-and-update-state.1 = u64[2]{0} rng-get-and-update-state(), delta=1
}
%region_0.58 (arg_tuple.59: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
%arg_tuple.59 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.37 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.59), index=0
%constant_63 = s32[] constant(1)
%add.65 = s32[] add(s32[] %get-tuple-element.37, s32[] %constant_63), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
%get-tuple-element.45 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.59), index=2
%add.64 = s32[] add(s32[] %get-tuple-element.45, s32[] %constant_63), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
ROOT %tuple.13 = (s32[], s32[], s32[]) tuple(s32[] %add.65, s32[] %add.64, s32[] %get-tuple-element.45)
}
%region_1.67 (arg_tuple.68: (s32[], s32[], s32[])) -> pred[] {
%arg_tuple.68 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.69 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.68), index=0
%constant_788 = s32[] constant(1)
ROOT %compare.76 = pred[] compare(s32[] %get-tuple-element.69, s32[] %constant_788), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
%rng-get-and-update-state.2 = u64[2]{0} rng-get-and-update-state(), delta=1
}
%region_0.86 (arg_tuple.87: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
%arg_tuple.87 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.49 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.87), index=0
%constant_91 = s32[] constant(1)
%add.93 = s32[] add(s32[] %get-tuple-element.49, s32[] %constant_91), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
%get-tuple-element.57 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.87), index=2
%add.92 = s32[] add(s32[] %get-tuple-element.57, s32[] %constant_91), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
ROOT %tuple.16 = (s32[], s32[], s32[]) tuple(s32[] %add.93, s32[] %add.92, s32[] %get-tuple-element.57)
}
%region_1.95 (arg_tuple.96: (s32[], s32[], s32[])) -> pred[] {
%arg_tuple.96 = (s32[], s32[], s32[]) parameter(0)
%get-tuple-element.97 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.96), index=0
%constant_794 = s32[] constant(1)
ROOT %compare.104 = pred[] compare(s32[] %get-tuple-element.97, s32[] %constant_794), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
%rng-get-and-update-state.3 = u64[2]{0} rng-get-and-update-state(), delta=1
}
ENTRY %main.114 (Arg_0.1: s32[]) -> s32[] {
%constant_379 = s32[] constant(0)
%copy.24 = s32[] copy(s32[] %constant_379)
%copy.19 = s32[] copy(s32[] %copy.24)
%copy.12 = s32[] copy(s32[] %copy.24)
%copy.13 = s32[] copy(s32[] %copy.24)
%copy.6 = s32[] copy(s32[] %copy.24)
%copy.7 = s32[] copy(s32[] %copy.24)
%copy = s32[] copy(s32[] %copy.24)
%copy.1 = s32[] copy(s32[] %copy.24)
%Arg_0.1 = s32[] parameter(0)
%tuple.4 = (s32[], s32[], s32[]) tuple(s32[] %copy, s32[] %copy.1, s32[] %Arg_0.1)
%while.0 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.4), condition=%region_1.11, body=%region_0.2, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
%get-tuple-element.21 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.0), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
%tuple.7 = (s32[], s32[], s32[]) tuple(s32[] %copy.6, s32[] %copy.7, s32[] %get-tuple-element.21)
%while.1 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.7), condition=%region_1.39, body=%region_0.30, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
%get-tuple-element.36 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.1), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
%tuple.11 = (s32[], s32[], s32[]) tuple(s32[] %copy.12, s32[] %copy.13, s32[] %get-tuple-element.36)
%while.2 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.11), condition=%region_1.67, body=%region_0.58, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
%get-tuple-element.48 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.2), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
%tuple.14 = (s32[], s32[], s32[]) tuple(s32[] %copy.24, s32[] %copy.19, s32[] %get-tuple-element.48)
%while.3 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.14), condition=%region_1.95, body=%region_0.86, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.3), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
}
```
> If wrapping with a fake RNG-based while loop can successfully prevent inlining (really?), implementing `noinline` should be easier. But I guess this trick can only prevent computation from being moved outside the loop. XLA can still clone the code at each call site.
Right, so I think the fake while loop would prevent inlining at the HLO->compile-to-whatever-backend level, i.e. prevent optimisation of the function wrt its calling context. Indeed I didn't mean that it would prevent the cloning issue.
> XLA STILL inlines all `foo` -- copying the while loop 4 times!!!
Bother. Looking at `mlir.cache_lowering`, it looks like this does exactly what I was hoping might work -- producing only a single function object and then re-using it whenever possible. It's really unfortunate that JAX/XLA seems to be limited like this.
Perhaps the way forward really is to implement a "better" version of `host_callback` that fits this use-case. The new work on effectful jaxprs might make this a lot easier -- these have so many use cases that I'm very excited to be getting algebraic effects in JAX -- since I assume this should remove a lot of the rewriting/hackery/etc. that `host_callback` seems to bring in.
Hi! I prototyped a (possibly) simpler version than `host_callback`, directly using `mlir.emit_python_callback`.
Thanks to jaxlib 0.3.11 adding `xla_python_gpu_callback`, we don't need to use the complex outfeed mechanism, though at some cost in performance: `xla_python_gpu_callback` uses numpy to transfer data.
(I found it is 20% faster than `host_callback.call` for a `(1024, 1024)` float32 identity function, but spends nearly 3x the time for `(8, 1024, 1024)` or larger float32 arrays.)
```python
from functools import lru_cache

import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun


def foo(x):
    print('traced')
    print(type(x))
    return x + 1


def foo_lowered(x):
    print(type(x))
    return x + 1


@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)


callback_p = core.Primitive('callback')
callback_p.def_impl(lambda *args, callback, **_: callback(*args))
callback_p.def_abstract_eval(lambda *args, callback, **_: abstract_eval_fun(callback, *args))


def callback_lowering(ctx, *args, callback, callback_lowered):
    # Lower the callback to an XLA CustomCall via `mlir.emit_python_callback`,
    # wrapping single-output callbacks so that they return a tuple.
    try:
        iter(abstract_eval_fun(callback, *ctx.avals_in))
    except TypeError:
        f = lambda *args: (callback_lowered(*args),)
    else:
        f = callback_lowered
    result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, f, args, ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result


mlir.register_lowering(callback_p, mlir.cache_lowering(callback_lowering))


def bar(x):
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x


print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())  # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())  # post-compilation HLO
```
```
traced
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
{ lambda ; a:i32[]. let
b:i32[] = callback[
callback=<function foo at 0x7f78a3d13790>
callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
] a
c:i32[] = callback[
callback=<function foo at 0x7f78a3d13790>
callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
] b
d:i32[] = callback[
callback=<function foo at 0x7f78a3d13790>
callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
] c
e:i32[] = callback[
callback=<function foo at 0x7f78a3d13790>
callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
] d
in (e,) }
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
5
module @jit_bar.1 {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
%0 = call @callback(%arg0) : (tensor<i32>) -> tensor<i32>
%1 = call @callback(%0) : (tensor<i32>) -> tensor<i32>
%2 = call @callback(%1) : (tensor<i32>) -> tensor<i32>
%3 = call @callback(%2) : (tensor<i32>) -> tensor<i32>
return %3 : tensor<i32>
}
func.func private @callback(%arg0: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<94158890291952> : tensor<i64>
%1 = "mhlo.custom_call"(%0, %arg0) {api_version = 2 : i32, backend_config = "94158890291952", call_target_name = "xla_python_gpu_callback", called_computations = [], has_side_effect = false} : (tensor<i64>, tensor<i32>) -> tuple<tensor<i32>>
%2 = "mhlo.get_tuple_element"(%1) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %2 : tensor<i32>
}
}
HloModule jit_bar.2, entry_computation_layout={(s32[])->s32[]}
ENTRY %main.26 (Arg_0.1: s32[]) -> s32[] {
%constant_0 = s64[] constant(94158882601360)
%Arg_0.1 = s32[] parameter(0)
%custom-call.0 = (s32[]) custom-call(s64[] %constant_0, s32[] %Arg_0.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
%get-tuple-element.0 = s32[] get-tuple-element((s32[]) %custom-call.0), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
%custom-call.1 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.0), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
%get-tuple-element.1 = s32[] get-tuple-element((s32[]) %custom-call.1), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
%custom-call.2 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
%get-tuple-element.2 = s32[] get-tuple-element((s32[]) %custom-call.2), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
%custom-call.3 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.2), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[]) %custom-call.3), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
}
```
BTW, `host_callback` uses the `CustomCall` implementation already by default on CPU (no `outfeed`).
There is a new mechanism being developed to replace `host_callback`, but it is not yet ready.
In any case, using `host_callback` here seems more like a workaround than the best solution.
Closing in favour of #218, which adds `equinox.internal.noinline`. This offers a heavily improved version of this functionality:
- Much faster.
- Uses a custom primitive directly, instead of monkey-patching `host_callback.call`.
- Offers the ability to recompile only the part of a computation graph that has changed, without needing to recompile the entire graph. For example:
```python
import equinox as eqx
import jax.numpy as jnp
from equinox.internal import noinline

def abstract(x, y):
    return jnp.broadcast_arrays(x, y)[0]

def f(x, y):
    print("Compiling f!")
    return x + y

def g(x, y):
    print("Compiling g!")
    return x * y

f = noinline(f, abstract)
g = noinline(g, abstract)

def call(fn, x, y):
    print("Compiling call!")
    return fn(x, y)

call = eqx.filter_jit(call)

call(f, 1, 1)  # Compiling call! Compiling f!
call(g, 1, 1)  # Compiling g! [But does not recompile call!]
```