
`EQX_ON_ERROR=breakpoint` and `diffeqsolve` cause tracer leakage?

Open thibmonsel opened this issue 1 year ago • 2 comments

Hi Patrick, I'm getting weird behavior with HEAD when it is used with EQX_ON_ERROR=breakpoint. Here is an MWE:


import jax
import jax.numpy as jnp
import diffrax

with jax.checking_leaks():

    ts = jnp.linspace(0.0, 1.0, 10)
    ys = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -y),
        diffrax.Bosh3(),
        t0=ts[0],
        t1=ts[-1],
        dt0=ts[1] - ts[0],
        y0=jnp.ones((1, )),
        saveat=diffrax.SaveAt(ts=ts),
    )

Running the file once works fine. If I then set the environment variable with export EQX_ON_ERROR=breakpoint in another terminal and re-run the MWE, the traceback reports leaked tracers (and EQX_ON_ERROR=breakpoint doesn't open a jax.debug.breakpoint where the error arises):

Traceback (most recent call last):
  File "/home/monsel/Desktop/sandbox_diffrax/mwe.py", line 10, in <module>
    ys = diffrax.diffeqsolve(
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 239, in __call__
    return self._call(False, args, kwargs)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_module.py", line 1093, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_errors.py", line 187, in fixed_jit_impl
    return jit_fun(*args2, **kwargs2)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190526784> is referred to by <function 139338169797072> (_allocate_output) closed-over variable y0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169959664> is referred to by <function 139338169797072> (_allocate_output) closed-over variable t0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169486448> is referred to by <function 139338169797072> (_allocate_output) closed-over variable direction
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190525104> is referred to by <function 139338169246064> (_check_subsaveat_ts) closed-over variable t1
<function 139338169246064> is referred to by <list 139338138048256>[5]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

equinox==0.11.7
jax==0.4.34
jaxlib==0.4.34
jaxtyping==0.2.34
lineax==0.0.6
ml_dtypes==0.5.0
numpy==2.1.2
opt_einsum==3.4.0
optimistix==0.0.8
scipy==1.14.1
typeguard==2.13.3
typing_extensions==4.12.2

thibmonsel, Oct 10 '24 16:10

This is probably a variant of https://github.com/jax-ml/jax/issues/16732.

Equinox already has a workaround for the specific case reported in that issue (when EQX_ON_ERROR=breakpoint is set, we monkey-patch jax.jit to conditionally disable it), but I believe the same problem can also occur with some other JAX operations, like jax.custom_vjp.

I think what this really needs is someone to fix this in JAX itself, unfortunately.

Other than that, you can try setting JAX_DISABLE_JIT=1 to sidestep the issue.

patrick-kidger, Oct 10 '24 17:10

Thanks for the clear explanation! I'll give JAX_DISABLE_JIT=1 a try.
This bug is definitely misleading.

thibmonsel, Oct 10 '24 17:10
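
For reference, the JAX_DISABLE_JIT=1 workaround suggested above can also be scoped to a block of code from Python with jax.disable_jit(). The following is only a minimal sketch: the toy function decay is a hypothetical stand-in for the vector field in the MWE, and the thread does not confirm that disabling jit removes the leak in this particular case. It only shows that, with jit disabled, a jitted function receives concrete arrays rather than tracers.

```python
import jax
import jax.numpy as jnp

saw_tracer = []  # records whether `decay` received a tracer on each call


@jax.jit
def decay(y):
    # Hypothetical toy function standing in for the MWE's vector field.
    # When jit is active, `y` is a tracer while the function is being traced;
    # when jit is disabled, `y` is an ordinary concrete array.
    saw_tracer.append(isinstance(y, jax.core.Tracer))
    return -y


y0 = jnp.ones((1,))

# Scoped equivalent of JAX_DISABLE_JIT=1: jit-wrapped functions run eagerly here.
with jax.disable_jit():
    eager_out = decay(y0)

# Outside the context manager jit is active again, so the function is traced.
jitted_out = decay(y0)

print(saw_tracer)
```

The same context manager could presumably be wrapped around the diffrax.diffeqsolve call in the MWE, at the cost of much slower execution. To disable jit for the whole process instead, jax.config.update("jax_disable_jit", True) has the same effect as the environment variable.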