`EQX_ON_ERROR=breakpoint` with `diffeqsolve` causes tracer leakage?
Hi Patrick, I'm getting weird behavior with HEAD when using `EQX_ON_ERROR=breakpoint`.
Here is an MWE:
```python
import jax
import jax.numpy as jnp
import diffrax

with jax.checking_leaks():
    ts = jnp.linspace(0.0, 1.0, 10)
    ys = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -y),
        diffrax.Bosh3(),
        t0=ts[0],
        t1=ts[-1],
        dt0=ts[1] - ts[0],
        y0=jnp.ones((1,)),
        saveat=diffrax.SaveAt(ts=ts),
    )
```
Running the file once works fine.
Then, if I set the environment variable with `export EQX_ON_ERROR=breakpoint` in another terminal and re-run the MWE, the traceback reports leaked tracers (and `EQX_ON_ERROR=breakpoint` does not open a `jax.debug.breakpoint` where the error arises):
```
Traceback (most recent call last):
  File "/home/monsel/Desktop/sandbox_diffrax/mwe.py", line 10, in <module>
    ys = diffrax.diffeqsolve(
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 239, in __call__
    return self._call(False, args, kwargs)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_module.py", line 1093, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_errors.py", line 187, in fixed_jit_impl
    return jit_fun(*args2, **kwargs2)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190526784> is referred to by <function 139338169797072> (_allocate_output) closed-over variable y0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169959664> is referred to by <function 139338169797072> (_allocate_output) closed-over variable t0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169486448> is referred to by <function 139338169797072> (_allocate_output) closed-over variable direction
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190525104> is referred to by <function 139338169246064> (_check_subsaveat_ts) closed-over variable t1
<function 139338169246064> is referred to by <list 139338138048256>[5]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
```
```
equinox==0.11.7
jax==0.4.34
jaxlib==0.4.34
jaxtyping==0.2.34
lineax==0.0.6
ml_dtypes==0.5.0
numpy==2.1.2
opt_einsum==3.4.0
optimistix==0.0.8
scipy==1.14.1
typeguard==2.13.3
typing_extensions==4.12.2
```
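Exporting the variable in another terminal works, but the mode can also be chosen per run from Python. The sketch below assumes, as the `fixed_jit_impl` frame in `equinox/_errors.py` in the traceback suggests, that Equinox reads `EQX_ON_ERROR` once when it is imported and patches `jax.jit` at that point, so the variable has to be in the environment before the script starts. `run_in_fresh_interpreter` is a hypothetical helper, not part of Equinox or Diffrax.

```python
import os
import subprocess
import sys


def run_in_fresh_interpreter(*argv, capture=False, **env_overrides):
    """Run `python *argv` in a new process with extra environment variables.

    Hypothetical helper. Assumption: Equinox reads EQX_ON_ERROR at import time,
    so setting os.environ after `import equinox` in the same process would not
    switch modes; a fresh interpreter sidesteps that.
    """
    # Copy the current environment and add/override the requested variables
    # (values must be strings).
    env = dict(os.environ, **env_overrides)
    # With capture=False, stdin/stdout stay attached to the terminal, which is
    # what an interactive breakpoint needs.
    return subprocess.run(
        [sys.executable, *argv],
        env=env,
        capture_output=capture,
        text=True,
    )


# Usage (equivalent to `EQX_ON_ERROR=breakpoint python mwe.py` in a shell):
# run_in_fresh_interpreter("mwe.py", EQX_ON_ERROR="breakpoint")
```

The one-line shell form in the comment is usually all you need; the helper only matters when driving several runs from a script.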
This is probably a variant of https://github.com/jax-ml/jax/issues/16732.
Equinox already has a workaround for the specific case reported there (when `EQX_ON_ERROR=breakpoint` is set, we monkey-patch `jax.jit` to conditionally disable it), but I believe the problem can also occur with other JAX operations, such as `jax.custom_vjp`.
I think what this really needs is someone to fix this in JAX itself, unfortunately.
Other than that, you can try setting `JAX_DISABLE_JIT=1` to sidestep the issue.
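A small JAX-only illustration of what disabling jit changes: a `jax.jit`-decorated function then runs eagerly, so its arguments are concrete arrays rather than tracers, and there is no `jit` trace for a tracer to leak out of. `jax.disable_jit()` is the scoped, in-code counterpart of the `JAX_DISABLE_JIT=1` environment variable; `double` is a toy function, and it is assumed (as suggested above) that Diffrax and Equinox honour the flag on the code paths the MWE uses.

```python
import jax
import jax.numpy as jnp


@jax.jit
def double(x):
    # float() needs a concrete value. Under jit, x is a tracer and this would
    # raise a ConcretizationTypeError; with jit disabled, x is a real array.
    print("x =", float(x))
    return 2.0 * x


# Scoped equivalent of launching with JAX_DISABLE_JIT=1: inside this block,
# jax.jit does not trace, so the function body runs op by op.
with jax.disable_jit():
    y = double(jnp.array(3.0))
```

Expect the solve to run noticeably slower this way, since every operation is dispatched individually; it is a debugging aid rather than a fix.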
Thanks for the clear explanation! I'll give `JAX_DISABLE_JIT=1` a try.
This bug is definitely misleading.