Vmapped jnp.interp raises an error when debugging NaNs with JIT disabled
We've run into a strange error when using jnp.interp while debugging NaNs with JIT disabled (JAX_DEBUG_NANS=1 together with JAX_DISABLE_JIT=1). I'm aware that combining these two flags can cause issues and that enabling only one of them is typically recommended. Nevertheless, it might be worth looking into, in case it's an easy fix or a sign of something more serious.
MWE:
%env JAX_DEBUG_NANS=1
%env JAX_DISABLE_JIT=1
import diffrax
import jax.numpy as jnp
xp = jnp.array([0,1,2])
fp = jnp.array([0.0,1.0,0.0])
def vector_field(t, u, args):
    return jnp.interp(u, xp, fp)

result = diffrax.diffeqsolve(diffrax.ODETerm(vector_field),
                             diffrax.Euler(),
                             t0=0.0,
                             t1=1.0,
                             dt0=0.01,
                             y0=jnp.array([0.5, 1.5]))
If I disable either flag (or both), this runs fine. It also runs fine if y0 is a scalar or has length 1, or if I pass adjoint=diffrax.ForwardMode() or adjoint=diffrax.DirectAdjoint(). Furthermore, the following also runs fine with both flags on:
jnp.interp(jnp.array([0.5, 1.5]), xp, fp)
My theory is that this has something to do with RecursiveCheckpointAdjoint (Diffrax's default adjoint). However, it doesn't seem to be simply that jnp.interp is not reverse-mode differentiable, as I can compute its jacrev just fine.
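To check the jacrev observation in isolation, here is a minimal sketch that evaluates `jnp.interp` and its reverse-mode Jacobian at the same points with both flags on. It assumes a JAX version (such as the 0.6.0 reported in this thread) where `jax.debug_nans` and `jax.disable_jit` can be used as context managers, as scoped stand-ins for the two environment variables.

```python
import jax
import jax.numpy as jnp

xp = jnp.array([0, 1, 2])
fp = jnp.array([0.0, 1.0, 0.0])
u = jnp.array([0.5, 1.5])


def interp_fn(v):
    # Same interpolant as the vector field in the MWE, without diffrax.
    return jnp.interp(v, xp, fp)


# Scoped equivalents of JAX_DEBUG_NANS=1 and JAX_DISABLE_JIT=1.
with jax.debug_nans(True), jax.disable_jit():
    values = interp_fn(u)
    # Reverse-mode Jacobian of the elementwise interpolant.
    jac = jax.jacrev(interp_fn)(u)

print(values)
print(jac)
```

Because interpolation is elementwise, the Jacobian should be diagonal, holding the local slopes of the piecewise-linear interpolant (+1 on [0, 1] and -1 on [1, 2]).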
Traceback
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
<frozen runpy> in ?()
--> 198 'Could not get source, probably due dynamically evaluated source code.'
<frozen runpy> in ?()
---> 88 'Could not get source, probably due dynamically evaluated source code.'
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel_launcher.py in ?()
---> 18 """Entry point for launching an IPython kernel.
19
~/.virtualenvs/ergodic/lib/python3.13/site-packages/traitlets/config/application.py in ?()
-> 1075 app.start()
1076
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelapp.py in ?()
--> 739 self.io_loop.start()
~/.virtualenvs/ergodic/lib/python3.13/site-packages/tornado/platform/asyncio.py in ?()
--> 205 self.asyncio_loop.run_forever()
/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/base_events.py in ?()
--> 679 self._run_once()
/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/base_events.py in ?()
-> 2027 handle._run()
/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/events.py in ?()
---> 89 self._context.run(self._callback, *self._args)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 545 await self.process_one()
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 534 await dispatch(*args)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 437 await result
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/ipkernel.py in ?()
--> 362 await super().execute_request(stream, ident, parent)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 778 reply_content = await reply_content
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/ipkernel.py in ?()
--> 449 res = shell.run_cell(
~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/zmqshell.py in ?()
--> 549 return super().run_cell(*args, **kwargs)
/var/folders/7l/46fqh6j56rv1n6xbrntn3bmm0000gn/T/ipykernel_3478/3052998300.py in ?()
---> 13 get_ipython().run_line_magic('env', 'JAX_DEBUG_NANS=1')
14 get_ipython().run_line_magic('env', 'JAX_DISABLE_JIT=1')
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
-> 1416 final_state, aux_stats = adjoint.loop(
1417 args=args,
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_adjoint.py in ?()
--> 299 final_state = self._loop(
300 terms=terms,
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
--> 638 final_state = outer_while_loop(
639 cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py in ?()
---> 85 return func(*args, **kwds)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/loop.py in ?()
--> 107 return checkpointed_while_loop(
~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py in ?()
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/common.py in ?()
--> 471 buffer_val2 = body_fun(buffer_val)
472 # Needed to work with `disable_jit`, as then we lose the automatic
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
--> 635 new_state, _, _ = body_fun_aux(state)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
--> 349 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
350 terms,
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_solver/euler.py in ?()
---> 60 y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_term.py in ?()
--> 756 return self.term.vf_prod(t, y, args, control)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_term.py in ?()
--> 157 return self.prod(self.vf(t, y, args), control)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_term.py in ?()
--> 194 out = self.vector_field(t, y, args)
195 if jtu.tree_structure(out) != jtu.tree_structure(y):
/var/folders/7l/46fqh6j56rv1n6xbrntn3bmm0000gn/T/ipykernel_3478/3052998300.py in ?()
---> 11 return jnp.interp(u, xp, fp)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 2752 return jitted_interp(x, xp, fp, left, right, period)
2753
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 2668 i = clip(searchsorted(xp_arr, x_arr, side='right'), 1, len(xp_arr) - 1)
2669 df = fp_arr[i] - fp_arr[i - 1]
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 9992 return impl(a, v, side, dtype) # type: ignore
9993
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/vectorize.py in ?()
--> 346 result = vectorized_func(*squeezed_args)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/vectorize.py in ?()
--> 144 out = func(*args)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/vectorize.py in ?()
--> 187 return func(*args, **kwargs, **static_kwargs)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 9881 carry, _ = lax.scan(body_fun, init, (), length=n_levels,
9882 unroll=n_levels if unrolled else 1)
~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 9877 go_left = op(query, sorted_arr[mid])
JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in broadcast_in_dim
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
FloatingPointError Traceback (most recent call last)
Cell In[1], line 13
10 def vector_field(t, u, args):
11 return jnp.interp(u, xp, fp)
---> 13 result = diffrax.diffeqsolve(diffrax.ODETerm(vector_field),
14 diffrax.Euler(),
15 t0=0.0,
16 t1=1.0,
17 dt0=0.01,
18 y0=jnp.array([0.5, 1.5]))
[... skipping hidden 2 frame]
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py:55, in _filter_jit_cache.<locals>.fun_wrapped(dynamic_donate, dynamic_nodonate, static)
53 *args, dummy_arg = (first_arg,) + rest_args
54 assert dummy_arg is None
---> 55 out = fun(*args, **kwargs)
56 dynamic_out, static_out = partition(out, is_array)
57 marker = jnp.array(0)
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
1389 init_state = State(
1390 y=y0,
1391 tprev=tprev,
(...) 1409 event_mask=event_mask,
1410 )
1412 #
1413 # Main loop
1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
1417 args=args,
1418 terms=terms,
1419 solver=solver,
1420 stepsize_controller=stepsize_controller,
1421 event=event,
1422 saveat=saveat,
1423 t0=t0,
1424 t1=t1,
1425 dt0=dt0,
1426 max_steps=max_steps,
1427 init_state=init_state,
1428 throw=throw,
1429 passed_solver_state=passed_solver_state,
1430 passed_controller_state=passed_controller_state,
1431 progress_meter=progress_meter,
1432 )
1434 #
1435 # Finish up
1436 #
1438 progress_meter.close(final_state.progress_meter_state)
[... skipping hidden 1 frame]
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
295 outer_while_loop = ft.partial(
296 _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
297 )
298 msg = None
--> 299 final_state = self._loop(
300 terms=terms,
301 saveat=saveat,
302 init_state=init_state,
303 max_steps=max_steps,
304 inner_while_loop=inner_while_loop,
305 outer_while_loop=outer_while_loop,
306 **kwargs,
307 )
308 if msg is not None:
309 final_state = eqxi.nondifferentiable_backward(
310 final_state, msg=msg, symbolic=True
311 )
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py:638, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
635 new_state, _, _ = body_fun_aux(state)
636 return new_state
--> 638 final_state = outer_while_loop(
639 cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
640 )
641 result = final_state.result
643 if event is None or event.root_finder is None:
File /opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py:85, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
82 @wraps(func)
83 def inner(*args, **kwds):
84 with self._recreate_cm():
---> 85 return func(*args, **kwds)
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
105 elif kind == "checkpointed":
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
116 del kind, checkpoints
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:249, in checkpointed_while_loop(***failed resolving arguments***)
247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
--> 249 final_val_ = _checkpointed_while_loop(
250 vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
251 )
252 _, _, _, final_val = _stop_gradient_on_unperturbed(init_val_, final_val_, body_fun_)
253 return final_val
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_ad.py:1107, in filter_custom_vjp.__call__(self, vjp_arg, *args, **kwargs)
1103 array_args_kwargs, nonarray_args_kwargs = partition((args, kwargs), is_array)
1104 array_args_kwargs = nondifferentiable(
1105 array_args_kwargs, name="`*args` and `**kwargs` to `filter_custom_vjp`"
1106 )
-> 1107 out = self.fn_wrapped(
1108 nonarray_vjp_arg,
1109 nonarray_args_kwargs,
1110 diff_array_vjp_arg,
1111 nondiff_array_vjp_arg,
1112 array_args_kwargs,
1113 )
1114 diff_array_out, nondiff_array_out, nonarray_out = out
1115 return combine(diff_array_out, nondiff_array_out, nonarray_out.value)
[... skipping hidden 11 frame]
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:271, in _checkpointed_while_loop(***failed resolving arguments***)
268 while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
269 # Hashable wrapper; JAX issue #13554 and
270 # https://github.com/patrick-kidger/equinox/issues/768
--> 271 return while_loop(lambda x: cond_fun(x), lambda x: body_fun(x), init_val)
File /opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py:85, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
82 @wraps(func)
83 def inner(*args, **kwds):
84 with self._recreate_cm():
---> 85 return func(*args, **kwds)
[... skipping hidden 2 frame]
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:271, in _checkpointed_while_loop.<locals>.<lambda>(x)
268 while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
269 # Hashable wrapper; JAX issue #13554 and
270 # https://github.com/patrick-kidger/equinox/issues/768
--> 271 return while_loop(lambda x: cond_fun(x), lambda x: body_fun(x), init_val)
[... skipping hidden 8 frame]
File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api.py:121, in _nan_check_posthook(fun, args, kwargs, output)
119 f = fun._fun
120 if getattr(f, '_apply_primitive', False):
--> 121 raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
122 # compiled_fun can only raise in this case
123 dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
FloatingPointError: invalid value (nan) encountered in broadcast_in_dim
cc @aidancrilly, who first noticed this issue.
With which versions of JAX, equinox and diffrax do you get this error?
jax version: 0.6.0
equinox version: 0.12.1
diffrax version: 0.7.0
I can reproduce it with this JAX-only MWE:
import os
os.environ["JAX_DEBUG_NANS"] = "1"
os.environ["JAX_DISABLE_JIT"] = "1"
import jax
import jax.core
import jax.numpy as jnp
t = jnp.array([0.5, 1.5])
xp = jnp.array([0,1,2])
fp = jnp.array([0.0,1.0,0.0])
jaxpr = jax.make_jaxpr(lambda t: jnp.interp(t, xp, fp))(t)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, t)
It looks like the culprit is this line:
https://github.com/jax-ml/jax/blob/8c987bfb0afd36710038a96430ae95f473503e08/jax/_src/lax/lax.py#L7923
Whilst it is wrapped in a debug_nans(False) context during the initial tracing into a jaxpr, that context is not recorded in the jaxpr itself, so it is no longer active when the jaxpr is subsequently evaluated.
I'm not immediately sure what would be a reasonable workaround from your perspective, though.
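The mechanism described above can be reproduced without `jnp.interp` at all. The sketch below uses a hypothetical helper, `canonicalize_nans` (not JAX code), which creates an all-NaN intermediate inside a `debug_nans(False)` block. Calling it eagerly should pass the NaN check, whereas tracing it with `jax.make_jaxpr` and then evaluating the jaxpr with `jax.core.eval_jaxpr` (the same two steps as the JAX-only MWE above) should raise, because the context manager does not survive into the jaxpr. It assumes a JAX version that still exposes `jax.core.eval_jaxpr`, as the MWE does.

```python
import jax
import jax.numpy as jnp


def canonicalize_nans(x):
    # Hypothetical helper (not part of JAX): replaces NaNs in x with a fresh NaN.
    # The all-NaN intermediate is built under debug_nans(False), so the NaN
    # checker skips it while this block is active.
    with jax.debug_nans(False):
        return jnp.where(jnp.isnan(x), jnp.full_like(x, jnp.nan), x)


x = jnp.array([0.5, 1.5])

# 1) Eager call: the debug_nans(False) block is live while the NaN is created,
#    so no error is raised, and the output (equal to x) contains no NaNs.
with jax.debug_nans(True):
    eager_out = canonicalize_nans(x)

# 2) Trace to a jaxpr. No NaN checking happens while tracing.
closed = jax.make_jaxpr(canonicalize_nans)(x)

# 3) Evaluate the jaxpr primitive by primitive with NaN checking on. The
#    debug_nans(False) block was not recorded in the jaxpr, so the all-NaN
#    intermediate is now flagged.
eval_error = None
with jax.debug_nans(True):
    try:
        jax.core.eval_jaxpr(closed.jaxpr, closed.consts, x)
    except FloatingPointError as err:
        eval_error = err

print(eager_out)
print(eval_error)
```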
Thanks. Do you know why it depends on the choice of adjoint?
I think it's simply that one of the adjoint methods includes code of the above sort, whilst the others do not.
(I'm tagging this as a question rather than a bug, as sadly I don't think anything wrong is occurring in Diffrax here, or that there's anything we can do to work around it in Diffrax.)
Thanks Patrick. I did some digging and noticed there is some commentary about this in jax._src.core.jaxpr_as_fun. I'm guessing there's no clever way we could use jaxpr_as_fun with RecursiveCheckpointAdjoint? It sounds like there may be ambitions to add contexts to jaxprs one day, which should solve the problem. Do you think it's worth raising this as an issue on the JAX side? My gut feeling is that the fix isn't trivial there either.
For posterity, if anyone encounters something similar in the future, my current approaches to debugging these issues are to:
- reduce max_steps considerably and use DirectAdjoint (to avoid memory issues), or
- just manually run a single step in isolation, without diffrax. This enabled me to debug a nasty reverse-pass NaN I wasn't able to pin down otherwise.
(Unfortunately, in my limited experience, debugging NaNs without disabling JIT typically gives very opaque error messages.)
From what I've seen, I suspect this is probably a wontfix on the JAX side, although I don't wish to speak on behalf of their team.