equinox
Bounded while loop
While working on some modifications to DirectAdjoint, I kept hitting empty `AssertionError`s (no message) coming from bounded while loops:
```text
test/test_adjoint.py:333: in _run_inexact
    return _run(eqx.combine(inexact, static), saveat, adjoint)
test/test_adjoint.py:291: in _run
    ys = diffrax.diffeqsolve(
diffrax/_integrate.py:1462: in diffeqsolve
    final_state, aux_stats = adjoint.loop(
diffrax/_adjoint.py:405: in loop
    final_state = self._loop(
diffrax/_integrate.py:641: in loop
    final_state = outer_while_loop(
../../miniforge3/envs/dev_diffrax/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/loop.py:119: in while_loop
    return bounded_while_loop(
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/bounded.py:59: in bounded_while_loop
    _, _, _, val = _while_loop(cond_fun_, body_fun_, init_val_, rounded_max_steps, base)
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/bounded.py:78: in _while_loop
    return lax.scan(scan_fn, val, xs=None, length=base)[0]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
>   return lax.cond(cond_fun(val), call, lambda x: x, val), None
E   jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError
```
Have you seen this before? I'll work on a minimal reproducible example in the meantime.
This may belong in diffrax, but since the core code is Equinox's bounded while loop (and DirectAdjoint is a fairly thin layer over it), I've opened it here.
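For context, the last frames of the traceback (`bounded.py` lines 59 and 78) suggest that a bounded while loop is built from nested `lax.scan` calls whose bodies are guarded by `lax.cond`. The following is a rough pure-JAX sketch of that structure, assuming this design; the names `bounded_while_loop_sketch` and `_nested_loop` are hypothetical, this is not Equinox's API, and Equinox's real implementation handles more cases than shown here.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _nested_loop(cond_fun, body_fun, val, max_steps, base):
    # Base case: one step is a single guarded application of body_fun.
    if max_steps == 1:
        return lax.cond(cond_fun(val), body_fun, lambda x: x, val)

    # Recursive case: `base` sub-loops, each covering max_steps // base steps.
    def call(v):
        return _nested_loop(cond_fun, body_fun, v, max_steps // base, base)

    def scan_fn(v, _):
        # Skip the whole sub-loop once the condition has become false.
        return lax.cond(cond_fun(v), call, lambda x: x, v), None

    return lax.scan(scan_fn, val, xs=None, length=base)[0]


def bounded_while_loop_sketch(cond_fun, body_fun, init_val, max_steps, base=16):
    # Carry a step counter so the loop never runs more than max_steps steps,
    # even though the nesting below is rounded up to a power of `base`.
    def cond_(carry):
        step, val = carry
        return jnp.logical_and(step < max_steps, cond_fun(val))

    def body_(carry):
        step, val = carry
        return step + 1, body_fun(val)

    rounded = 1
    while rounded < max_steps:
        rounded *= base

    _, out = _nested_loop(cond_, body_, (jnp.array(0), init_val), rounded, base)
    return out
```

Because every step of the loop passes through `lax.cond`, any transformation applied to the loop (for example the differentiation that DirectAdjoint relies on) also goes through the `cond` rules, which would be consistent with the assertion surfacing at the `lax.cond` line. Reverse-mode differentiation through the sketch works as expected, e.g. `jax.grad(lambda a: bounded_while_loop_sketch(lambda x: x < 100.0, lambda x: x * a, jnp.float32(1.0), max_steps=100))(jnp.float32(3.0))` differentiates `a**5`.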
Hmmm, nope, this one isn't familiar. The traceback points at a totally innocuous-looking line, so this might be coming from JAX internals, probably an `assert` statement inside one of the `cond_p` rules? I'll take a look at the MWE once you have it. :)