
Bounded while loop

Open lockwo opened this issue 1 year ago • 2 comments

While working on some DirectAdjoint modifications, I kept hitting empty AssertionErrors (no message) raised from inside bounded while loops:

test/test_adjoint.py:333: in _run_inexact
    return _run(eqx.combine(inexact, static), saveat, adjoint)
test/test_adjoint.py:291: in _run
    ys = diffrax.diffeqsolve(
diffrax/_integrate.py:1462: in diffeqsolve
    final_state, aux_stats = adjoint.loop(
diffrax/_adjoint.py:405: in loop
    final_state = self._loop(
diffrax/_integrate.py:641: in loop
    final_state = outer_while_loop(
../../miniforge3/envs/dev_diffrax/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/loop.py:119: in while_loop
    return bounded_while_loop(
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/bounded.py:59: in bounded_while_loop
    _, _, _, val = _while_loop(cond_fun_, body_fun_, init_val_, rounded_max_steps, base)
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/bounded.py:78: in _while_loop
    return lax.scan(scan_fn, val, xs=None, length=base)[0]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   return lax.cond(cond_fun(val), call, lambda x: x, val), None
E   jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError

Have you ever seen this before? I'll work on getting an MVC in the meantime.

lockwo, Jan 04 '25 23:01

Maybe this belongs in diffrax, but since the failing code is Equinox's bounded while loop (and DirectAdjoint is a pretty thin layer over it), I put it here.
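For context, the traceback shows the shape of a bounded while loop: a `lax.scan` of fixed length whose step wraps the body in `lax.cond(cond_fun(val), ...)`, so iterations after the condition turns false become no-ops. Unlike `lax.while_loop`, this is reverse-mode differentiable, which is why DirectAdjoint relies on it. The sketch below is a simplified single-level version for illustration only; `bounded_while_loop_sketch` is a hypothetical name, and the real Equinox implementation nests several scans (note the `base` argument in the traceback) rather than using one flat scan.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop_sketch(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical flat version of a bounded while loop (not Equinox's code)."""

    def scan_fn(val, _):
        # Run the body only while the condition still holds; otherwise pass
        # the carry through unchanged. This mirrors the lax.cond line in the
        # traceback above.
        return lax.cond(cond_fun(val), body_fun, lambda x: x, val), None

    # Fixed trip count, so the loop can be differentiated in reverse mode.
    return lax.scan(scan_fn, init_val, xs=None, length=max_steps)[0]


def run(x0):
    # Add 3 until the value reaches 10, with at most 100 iterations.
    return bounded_while_loop_sketch(
        lambda x: x < 10.0, lambda x: x + 3.0, x0, max_steps=100
    )


final = run(jnp.float32(0.0))               # 0 -> 3 -> 6 -> 9 -> 12
d_final = jax.grad(run)(jnp.float32(0.0))   # reverse mode works through scan + cond
```

The gradient is taken through `lax.cond`'s differentiation rules, which is where an internal JAX assertion during transformation would surface.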

lockwo, Jan 04 '25 23:01

Hmmm, nope, this one isn't familiar. The line in the traceback is totally innocuous-looking, so this might be coming from JAX internals: probably an `assert` statement inside one of the `cond_p` rules? I'll take a look at the MWE once you have it. :)
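One way to check whether the empty AssertionError comes from a JAX-internal rule is to disable JAX's traceback filtering, so frames inside JAX are shown instead of being collapsed into `JaxStackTraceBeforeTransformation`. A minimal sketch, assuming a JAX version where the `jax_traceback_filtering` option accepts `"off"` (the same as setting `JAX_TRACEBACK_FILTERING=off`); `step` is a hypothetical stand-in for the failing loop body, not the DirectAdjoint code.

```python
import jax
from jax import lax

# Show JAX-internal frames in tracebacks, so an assert inside a primitive
# rule (e.g. one of the cond_p rules) points at the actual failing line.
jax.config.update("jax_traceback_filtering", "off")


def step(x):
    # Hypothetical loop step: the same cond pattern as the bounded while loop.
    return lax.cond(x < 10.0, lambda y: y + 3.0, lambda y: y, x)


# Re-run the transformation that fails (here, reverse-mode AD) with full tracebacks.
grad_val = jax.grad(step)(0.0)
```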

patrick-kidger, Jan 05 '25 08:01