jax
jax_jit.cc support for Tracer: don't cache_miss
Working with complex jitted functions can lead to long compile times, which hurts interactive workflows. One option is to jit only pieces of the function. This works well as long as jit is the only top-level function transformation. When it is not, say when the jitted pieces are called under other transforms, things slow down dramatically due to Python overhead. For example:
import jax

@jax.jit
def fun(x):
    return x * 2

def grad(y):
    # Each call to the jitted fun receives a tracer once jax.grad is applied.
    for i in range(10):
        y = fun(y)
    return y

grad = jax.grad(grad)(1.)
Upon investigation, JAX's C++ jit codepath starts missing its cache (cache_miss in the profiler) and falls back to the Python jit codepath. The VLOG output below shows this happening, for example with BatchTracers and JVPTracers. The fallback is quite slow: it takes 33 ms in my case, where the actual computation takes 1.1 ms.
I0603 13:10:37.163035 3754 jax_jit.cc:892] ComputeSignature failed: INVALID_ARGUMENT: Not supported: The C++ ToPyArgSignature only accepts Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax.interpreters.batching.BatchTracer'>
I0603 13:10:37.186771 3754 jax_jit.cc:892] ComputeSignature failed: INVALID_ARGUMENT: Not supported: The C++ ToPyArgSignature only accepts Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax.interpreters.ad.JVPTracer'>
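For anyone who wants to reproduce the overhead, here is a minimal timing sketch. It compares calling jax.grad over a loop of jitted calls (where the inner jit sees tracers) against wrapping the whole gradient in an outer jit (where the call receives concrete values). The names loop, grad_loop, jitted_grad and time_call are made up for this illustration, and the absolute numbers will depend on the JAX version and hardware, so treat it as a rough measurement rather than a benchmark.

```python
import time

import jax


@jax.jit
def fun(x):
    return x * 2


def loop(y):
    # Ten calls into the jitted function; under jax.grad each call gets a tracer.
    for _ in range(10):
        y = fun(y)
    return y


# Case 1: jit is not the outermost transformation.
grad_loop = jax.grad(loop)

# Case 2: jit is the outermost transformation, so the call sees concrete inputs.
jitted_grad = jax.jit(jax.grad(loop))


def time_call(f, x, repeats=20):
    """Average wall-clock seconds per call, after one warm-up call."""
    f(x).block_until_ready()  # warm-up: triggers tracing and compilation
    start = time.perf_counter()
    for _ in range(repeats):
        f(x).block_until_ready()
    return (time.perf_counter() - start) / repeats


t_inner = time_call(grad_loop, 1.0)
t_outer = time_call(jitted_grad, 1.0)
print(f"grad over jitted pieces: {t_inner * 1e3:.3f} ms per call")
print(f"jit over grad:           {t_outer * 1e3:.3f} ms per call")
```

Both variants compute the same gradient; only the per-call dispatch path differs.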
Would it be possible to add C++ paths for these tracer types, or somehow strip the tracer before executing the C++ codepath?
This is something we have talked about implementing (C++ jit for tracer values) for quite some time, but we have never gotten around to it. It makes sense to do; we just need to do it!
Out of curiosity, what debug setting was used to get this output? jax.config.update("jax_explain_cache_misses", True)?
It has been a few years, but something like TF_CPP_VMODULE=jax_jit=3 might work.
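A sketch of how that suggestion could be tried from Python. It assumes the variable has to be set before jax is first imported, and that the installed jaxlib still honors TF_CPP_VMODULE for the jax_jit module; neither is confirmed here, so the extra log lines may or may not show up.

```python
import os

# Assumption: the logging level is read when jaxlib is loaded, so set it
# before the first "import jax". The value is the one suggested above.
os.environ["TF_CPP_VMODULE"] = "jax_jit=3"

import jax


@jax.jit
def fun(x):
    return x * 2


# Calling the jitted function under jax.grad passes it a tracer, which is the
# situation the ComputeSignature messages in this issue were logged for.
result = jax.grad(fun)(1.0)
print(float(result))
```

The same variable can instead be set in the shell when launching the script, which avoids any question of import order.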