custom_jvp leaks tracers if traced arrays are passed as nondiff_argnums
```python
import jax
import jax.numpy as jnp
from functools import partial

jax.config.update("jax_check_tracer_leaks", True)

@partial(jax.custom_jvp, nondiff_argnums=(1,))
def f(x, indices):
    return x[indices]

@f.defjvp
def f_jvp(indices, primals, tangents):
    x, = primals
    x_dot, = tangents
    # The JVP rule calls f recursively with the nondiff argument.
    return f(x, indices), x_dot[indices]

x = jnp.arange(10.0)
indices = jnp.array([1, 3, 5])  # an array, so it is traced under jit

jax.jit(jax.jacobian(f))(x, indices)
```
```
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-7-6ccf5735404c> in <cell line: 20>()
18 indices = jnp.array([1, 3, 5])
19
---> 20 jax.jit(jax.jacobian(f))(x, indices)
[... skipping hidden 17 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in new_main(trace_type, dynamic, **payload)
1201 if t() is not None:
1202 leaked_tracers = maybe_find_leaked_tracers(t())
-> 1203 if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
1204
1205 @contextmanager
Exception: Leaked trace MainTrace(2,JaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[10]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156284128> is referred to by <list 132784156610688>[0]
<list 132784156610688> is referred to by <frame 132784230502464>
<frame 132784230502464> is referred to by <frame 97010772184064>
<frame 97010772184064> is referred to by <generator 132784155884816>
Traced<ShapedArray(float32[3]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156291008> is referred to by <list 132784155950080>[0]
<list 132784155950080> is referred to by <frame 132784230502464>
<frame 132784230502464> is referred to by <list 132784232647808>[3]
<list 132784232647808> is referred to by <FramesList 132785276236272>._frames
<FramesList 132785276236272> is referred to by <frame 97010772184064>
<frame 97010772184064> is referred to by <generator 132784155884816>
Traced<ShapedArray(float32[3]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156290608> is referred to by <tuple 132784231897008>[0]
<tuple 132784231897008> is referred to by <JaxprEqnRecipe 132784156756448>[1]
<JaxprEqnRecipe 132784156756448> is referred to by <JaxprTracer 132784156291008>
Traced<ShapedArray(int32[3,1]):JaxprTrace(level=2/0)>
<JaxprTracer 132784156287888> is referred to by <tuple 132784155944192>[1]
<tuple 132784155944192> is referred to by <JaxprEqnRecipe 132784156756560>[1]
<JaxprEqnRecipe 132784156756560> is referred to by <JaxprTracer 132784156290608>
<JaxprTracer 132784156290608> is referred to by <tuple 132784231897008>[0]
<tuple 132784231897008> is referred to by <JaxprEqnRecipe 132784156756448>[1]
<JaxprEqnRecipe 132784156756448> is referred to by <JaxprTracer 132784156291008>
```
The problem here is that indices is a traced array but is listed in nondiff_argnums. This is a user error, but we should fail with a more informative error message.
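Until custom_jvp does such a check itself, a user-side guard can at least turn the leak into a readable error. This is a sketch under stated assumptions, not JAX API: `check_nondiff_static` is a hypothetical helper that walks each nondiff argument with `jax.tree_util.tree_leaves` and rejects any `jax.Array` leaf. Because `isinstance(..., jax.Array)` is true for concrete arrays as well as tracers, this guard is stricter than one that rejects only tracers.

```python
import jax
import jax.numpy as jnp

def check_nondiff_static(args, nondiff_argnums):
    # Hypothetical helper (not part of JAX). Rejects array-valued nondiff
    # arguments up front. isinstance(leaf, jax.Array) is True for both
    # concrete arrays and tracers, so this is deliberately strict.
    for i in nondiff_argnums:
        for leaf in jax.tree_util.tree_leaves(args[i]):
            if isinstance(leaf, jax.Array):
                raise TypeError(
                    f"argument {i} is listed in nondiff_argnums but contains an "
                    f"array of shape {leaf.shape}; pass a hashable Python value "
                    "instead, or make it an ordinary argument and ignore its tangent."
                )

x = jnp.arange(10.0)
check_nondiff_static((x, (1, 3, 5)), (1,))  # a tuple of ints passes

try:
    check_nondiff_static((x, jnp.array([1, 3, 5])), (1,))
except TypeError as e:
    print(e)  # explains which argument is the problem
```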
I think this is actually a bug (although @mattjj and @froystig will know better, of course!). If we update f_jvp to:
```python
@f.defjvp
def f_jvp(indices, primals, tangents):
    x, = primals
    x_dot, = tangents
    # Index x directly instead of calling f recursively.
    return x[indices], x_dot[indices]
```
everything works as it should, without any leaked tracers. I haven't had a chance to dig into why, but there's a lot of subtlety around handling these recursive patterns (the original JVP rule calls f itself). Regardless, I don't think we need to require that "nondiff" args be "static".
OK, maybe I take that back! It's clear from this error message:
https://github.com/google/jax/blob/82d3cfb3c6f88321f0b29b4cc41134a464de82c2/jax/_src/custom_derivatives.py#L656-L662
that custom_vjp, at least, has this requirement: nondiff args cannot be tracers. I think it also follows from the implementation that they can't be tracers, because they're handled by baking them into the functions rather than binding them as inputs. So, judging from my example above, there are some cases where things might still work, but they shouldn't be expected to.
Perhaps it would be reasonable to add the tracer check from custom_vjp to custom_jvp?
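If the indices are known at trace time, the static requirement can be met by passing them as a hashable Python tuple and marking that argument static for jit, so nothing traced is ever baked into the JVP rule. A minimal sketch under that assumption; converting the tuple with `jnp.asarray` inside `f` is our choice, because indexing a 1-D array with a bare tuple would be read as multi-dimensional indexing.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_jvp, nondiff_argnums=(1,))
def f(x, indices):
    # indices is a static tuple of Python ints; asarray makes a constant.
    return x[jnp.asarray(indices)]

@f.defjvp
def f_jvp(indices, primals, tangents):
    x, = primals
    x_dot, = tangents
    # The recursive call is fine here: indices is not a tracer.
    return f(x, indices), x_dot[jnp.asarray(indices)]

x = jnp.arange(10.0)
indices = (1, 3, 5)  # hashable, so it can be a static jit argument

jac = jax.jit(jax.jacobian(f), static_argnums=1)(x, indices)
```

Each distinct tuple triggers a new compilation, which is the usual cost of static arguments.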
Edited to add: These are two relevant tests:
https://github.com/google/jax/blob/82d3cfb3c6f88321f0b29b4cc41134a464de82c2/tests/api_test.py#L7262-L7303
It looks like we currently expect custom_jvp to work with some tracers (BatchTracer) but not others. Reading the PR that added that comment (https://github.com/google/jax/pull/14263), I think it probably makes sense to add that check to custom_jvp, though some people may depend on the current, more relaxed behavior.
I think I ran into this issue because some JAX built-ins use custom_jvp; a minimal example is below. @dfm, I would have naively agreed with your first interpretation: there seems to be nothing conceptually intrinsic about nondiff arguments that requires them to be static.
```python
from jax import numpy as jnp
import jax

@jax.jit
def f(A, rtol):
    return jnp.linalg.pinv(A, rtol)

_A = f(jnp.eye(3), 0.001)
print(f"{_A=}")
```
```
Traceback (most recent call last):
File "/home/cdodd/proj/test.py", line 8, in <module>
_A = f(jnp.eye(3), 0.001)
^^^^^^^^^^^^^^^^^^^^
File "/opt/python/3.11.0/gcc-system/lib/python3.11/contextlib.py", line 144, in __exit__
next(self.gen)
^^^^^^^^^^^^^^^^^^
Exception: Leaked trace DynamicJaxprTrace. Leaked tracer(s):
Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
The error occurred while tracing the function f at /home/cdodd/proj/test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument rtol.
<DynamicJaxprTracer 139925252380672> is referred to by <Unhashable 139925252235440>
<Unhashable 139925252235440> is referred to by <tuple 139925450911040>[0]
<tuple 139925450911040> is referred to by <tuple 139927089822496>[0]
<tuple 139927089822496> is referred to by <tuple 139925252395520>[1]
<tuple 139925252395520> is referred to by <tuple 139925252395840>[1]
<tuple 139925252395840> is referred to by <WrappedFun 139925451088576>
<WrappedFun 139925451088576> is referred to by <function 139925451687712> (jvp_jaxpr_thunk) closed-over variable jvp
<function 139925451687712> is referred to by <function 139925451688672> (memoized) closed-over variable fn
<function 139925451688672> is referred to by <dict 139925252426496>['jvp_jaxpr_thunk']
<dict 139925252426496> is referred to by <JaxprEqn 139925451094912>
<JaxprEqn 139925451094912> is referred to by <list 139925451708160>[1]
<list 139925451708160> is referred to by <Jaxpr 139925451094432>
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
It’s true that “non-differentiable argument” need not imply “static argument”. But this feature is actually for declaring static arguments, and the name was chosen badly.
We should probably:
- Change the name to static_argnums (while keeping the old name working for backward compatibility);
- Add the check to custom_jvp;
- Add a new convenience feature for array-like (i.e. possibly traced) values that we don’t want involved in differentiation.
The third has limited upside because you can already get this behavior today by passing in ordinary arguments and just ignoring their corresponding tangents.
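The workaround described above might look like this, adapted from the original example: indices becomes an ordinary custom_jvp argument, so it may be traced under jit, and the JVP rule simply discards its tangent (for an integer input, JAX supplies a float0 tangent). This is a minimal sketch, not a prescribed API.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, indices):
    return x[indices]

@f.defjvp
def f_jvp(primals, tangents):
    x, indices = primals
    x_dot, _ = tangents  # ignore the (float0) tangent of the integer indices
    return f(x, indices), x_dot[indices]

x = jnp.arange(10.0)
indices = jnp.array([1, 3, 5])  # traced under jit, and that is fine now

jac = jax.jit(jax.jacobian(f))(x, indices)
```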
That makes sense. I may open a separate issue for cases like pinv; it seems like rtol shouldn't need to be static.