
Vmapped jnp.interp throws an error when debugging NaNs with JIT disabled

Open jpbrodrick89 opened this issue 8 months ago • 7 comments

We've run into a strange error when using jnp.interp inside a diffrax solve while debugging NaNs with JIT disabled (JAX_DEBUG_NANS=1 and JAX_DISABLE_JIT=1). I'm aware that issues like this can arise when combining these flags, and that enabling only one of them at a time is typically recommended. Nevertheless, it might be worth looking into in case it's an easy fix or a sign of something more serious going on.

MWE:

%env JAX_DEBUG_NANS=1
%env JAX_DISABLE_JIT=1
import diffrax
import jax.numpy as jnp

xp = jnp.array([0,1,2])
fp = jnp.array([0.0,1.0,0.0])


def vector_field(t, u, args):
    return jnp.interp(u, xp, fp)

result = diffrax.diffeqsolve(diffrax.ODETerm(vector_field),
                            diffrax.Euler(),
                            t0=0.0,
                            t1=1.0,
                            dt0=0.01,
                            y0=jnp.array([0.5, 1.5]))

If I disable either or both flags, this runs fine. It also runs fine if y0 is a scalar or has length 1, or if I pass adjoint=diffrax.ForwardMode() or adjoint=diffrax.DirectAdjoint(). Furthermore, the following also runs fine with both flags on:

jnp.interp(jnp.array([0.5, 1.5]), xp, fp)
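A minimal sketch of that standalone check. It assumes that the jax.debug_nans and jax.disable_jit context managers are an acceptable stand-in for the two environment variables; they scope the flags to a single block instead of the whole process.

```python
import jax
import jax.numpy as jnp

xp = jnp.array([0, 1, 2])
fp = jnp.array([0.0, 1.0, 0.0])

# Assumption: these context managers stand in for JAX_DEBUG_NANS=1 and
# JAX_DISABLE_JIT=1, but only for the duration of the `with` block.
with jax.debug_nans(True), jax.disable_jit():
    # Eager, vectorised call: every primitive is executed and NaN-checked
    # as it runs, with no intermediate jaxpr being re-evaluated later.
    out = jnp.interp(jnp.array([0.5, 1.5]), xp, fp)

print(out)  # [0.5 0.5]
```

Both queries sit halfway along a segment whose endpoints are 0 and 1, so both interpolate to 0.5.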

My theory is that this has something to do with RecursiveCheckpointAdjoint (the default adjoint). However, it doesn't seem to simply be that jnp.interp is not reverse-mode differentiable, as I can compute its jacrev just fine.
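One way to check reverse-mode differentiability of jnp.interp on its own is to take its Jacobian directly. The sketch below runs with the default flags, and the function name f is illustrative only:

```python
import jax
import jax.numpy as jnp

xp = jnp.array([0, 1, 2])
fp = jnp.array([0.0, 1.0, 0.0])


def f(u):
    # Elementwise piecewise-linear interpolation, as in vector_field above.
    return jnp.interp(u, xp, fp)


u = jnp.array([0.5, 1.5])
J = jax.jacrev(f)(u)

# Each output depends only on its own input, so the Jacobian is diagonal,
# holding the local slope: +1 on [0, 1] and -1 on [1, 2].
print(J)
```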

Traceback
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
<frozen runpy> in ?()
--> 198 'Could not get source, probably due dynamically evaluated source code.'

<frozen runpy> in ?()
---> 88 'Could not get source, probably due dynamically evaluated source code.'

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel_launcher.py in ?()
---> 18 """Entry point for launching an IPython kernel.
     19 

~/.virtualenvs/ergodic/lib/python3.13/site-packages/traitlets/config/application.py in ?()
-> 1075         app.start()
   1076 

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelapp.py in ?()
--> 739                 self.io_loop.start()

~/.virtualenvs/ergodic/lib/python3.13/site-packages/tornado/platform/asyncio.py in ?()
--> 205         self.asyncio_loop.run_forever()

/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/base_events.py in ?()
--> 679                 self._run_once()

/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/base_events.py in ?()
-> 2027                 handle._run()

/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/events.py in ?()
---> 89             self._context.run(self._callback, *self._args)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 545                 await self.process_one()

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 534         await dispatch(*args)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 437                     await result

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/ipkernel.py in ?()
--> 362         await super().execute_request(stream, ident, parent)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/kernelbase.py in ?()
--> 778             reply_content = await reply_content

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/ipkernel.py in ?()
--> 449                     res = shell.run_cell(

~/.virtualenvs/ergodic/lib/python3.13/site-packages/ipykernel/zmqshell.py in ?()
--> 549         return super().run_cell(*args, **kwargs)

/var/folders/7l/46fqh6j56rv1n6xbrntn3bmm0000gn/T/ipykernel_3478/3052998300.py in ?()
---> 13 get_ipython().run_line_magic('env', 'JAX_DEBUG_NANS=1')
     14 get_ipython().run_line_magic('env', 'JAX_DISABLE_JIT=1')

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
-> 1416     final_state, aux_stats = adjoint.loop(
   1417         args=args,

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_adjoint.py in ?()
--> 299         final_state = self._loop(
    300             terms=terms,

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
--> 638     final_state = outer_while_loop(
    639         cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers

/opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py in ?()
---> 85                 return func(*args, **kwds)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/loop.py in ?()
--> 107         return checkpointed_while_loop(

~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py in ?()
--> 247     body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248     vjp_arg = (init_val_, body_fun_)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/common.py in ?()
--> 471         buffer_val2 = body_fun(buffer_val)
    472         # Needed to work with `disable_jit`, as then we lose the automatic

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
--> 635         new_state, _, _ = body_fun_aux(state)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py in ?()
--> 349         (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    350             terms,

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_solver/euler.py in ?()
---> 60         y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_term.py in ?()
--> 756         return self.term.vf_prod(t, y, args, control)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_term.py in ?()
--> 157         return self.prod(self.vf(t, y, args), control)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_term.py in ?()
--> 194         out = self.vector_field(t, y, args)
    195         if jtu.tree_structure(out) != jtu.tree_structure(y):

/var/folders/7l/46fqh6j56rv1n6xbrntn3bmm0000gn/T/ipykernel_3478/3052998300.py in ?()
---> 11     return jnp.interp(u, xp, fp)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 2752   return jitted_interp(x, xp, fp, left, right, period)
   2753 

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 2668   i = clip(searchsorted(xp_arr, x_arr, side='right'), 1, len(xp_arr) - 1)
   2669   df = fp_arr[i] - fp_arr[i - 1]

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 9992   return impl(a, v, side, dtype)  # type: ignore
   9993 

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/vectorize.py in ?()
--> 346     result = vectorized_func(*squeezed_args)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/vectorize.py in ?()
--> 144     out = func(*args)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/vectorize.py in ?()
--> 187     return func(*args, **kwargs, **static_kwargs)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 9881   carry, _ = lax.scan(body_fun, init, (), length=n_levels,
   9882                       unroll=n_levels if unrolled else 1)

~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py in ?()
-> 9877     go_left = op(query, sorted_arr[mid])

JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in broadcast_in_dim

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

FloatingPointError                        Traceback (most recent call last)
Cell In[1], line 13
     10 def vector_field(t, u, args):
     11     return jnp.interp(u, xp, fp)
---> 13 result = diffrax.diffeqsolve(diffrax.ODETerm(vector_field),
     14                             diffrax.Euler(),
     15                             t0=0.0,
     16                             t1=1.0,
     17                             dt0=0.01,
     18                             y0=jnp.array([0.5, 1.5]))

    [... skipping hidden 2 frame]

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py:55, in _filter_jit_cache.<locals>.fun_wrapped(dynamic_donate, dynamic_nodonate, static)
     53 *args, dummy_arg = (first_arg,) + rest_args
     54 assert dummy_arg is None
---> 55 out = fun(*args, **kwargs)
     56 dynamic_out, static_out = partition(out, is_array)
     57 marker = jnp.array(0)

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1389 init_state = State(
   1390     y=y0,
   1391     tprev=tprev,
   (...)   1409     event_mask=event_mask,
   1410 )
   1412 #
   1413 # Main loop
   1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
   1417     args=args,
   1418     terms=terms,
   1419     solver=solver,
   1420     stepsize_controller=stepsize_controller,
   1421     event=event,
   1422     saveat=saveat,
   1423     t0=t0,
   1424     t1=t1,
   1425     dt0=dt0,
   1426     max_steps=max_steps,
   1427     init_state=init_state,
   1428     throw=throw,
   1429     passed_solver_state=passed_solver_state,
   1430     passed_controller_state=passed_controller_state,
   1431     progress_meter=progress_meter,
   1432 )
   1434 #
   1435 # Finish up
   1436 #
   1438 progress_meter.close(final_state.progress_meter_state)

    [... skipping hidden 1 frame]

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
    295     outer_while_loop = ft.partial(
    296         _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
    297     )
    298     msg = None
--> 299 final_state = self._loop(
    300     terms=terms,
    301     saveat=saveat,
    302     init_state=init_state,
    303     max_steps=max_steps,
    304     inner_while_loop=inner_while_loop,
    305     outer_while_loop=outer_while_loop,
    306     **kwargs,
    307 )
    308 if msg is not None:
    309     final_state = eqxi.nondifferentiable_backward(
    310         final_state, msg=msg, symbolic=True
    311     )

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/diffrax/_integrate.py:638, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
    635     new_state, _, _ = body_fun_aux(state)
    636     return new_state
--> 638 final_state = outer_while_loop(
    639     cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
    640 )
    641 result = final_state.result
    643 if event is None or event.root_finder is None:

File /opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py:85, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     82 @wraps(func)
     83 def inner(*args, **kwds):
     84     with self._recreate_cm():
---> 85         return func(*args, **kwds)

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
    105 elif kind == "checkpointed":
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":
    116     del kind, checkpoints

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:249, in checkpointed_while_loop(***failed resolving arguments***)
    247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248 vjp_arg = (init_val_, body_fun_)
--> 249 final_val_ = _checkpointed_while_loop(
    250     vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251 )
    252 _, _, _, final_val = _stop_gradient_on_unperturbed(init_val_, final_val_, body_fun_)
    253 return final_val

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_ad.py:1107, in filter_custom_vjp.__call__(self, vjp_arg, *args, **kwargs)
   1103 array_args_kwargs, nonarray_args_kwargs = partition((args, kwargs), is_array)
   1104 array_args_kwargs = nondifferentiable(
   1105     array_args_kwargs, name="`*args` and `**kwargs` to `filter_custom_vjp`"
   1106 )
-> 1107 out = self.fn_wrapped(
   1108     nonarray_vjp_arg,
   1109     nonarray_args_kwargs,
   1110     diff_array_vjp_arg,
   1111     nondiff_array_vjp_arg,
   1112     array_args_kwargs,
   1113 )
   1114 diff_array_out, nondiff_array_out, nonarray_out = out
   1115 return combine(diff_array_out, nondiff_array_out, nonarray_out.value)

    [... skipping hidden 11 frame]

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:271, in _checkpointed_while_loop(***failed resolving arguments***)
    268 while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
    269 # Hashable wrapper; JAX issue #13554 and
    270 # https://github.com/patrick-kidger/equinox/issues/768
--> 271 return while_loop(lambda x: cond_fun(x), lambda x: body_fun(x), init_val)

File /opt/homebrew/Cellar/python@3.13/3.13.0_1/Frameworks/Python.framework/Versions/3.13/lib/python3.13/contextlib.py:85, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     82 @wraps(func)
     83 def inner(*args, **kwds):
     84     with self._recreate_cm():
---> 85         return func(*args, **kwds)

    [... skipping hidden 2 frame]

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:271, in _checkpointed_while_loop.<locals>.<lambda>(x)
    268 while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
    269 # Hashable wrapper; JAX issue #13554 and
    270 # https://github.com/patrick-kidger/equinox/issues/768
--> 271 return while_loop(lambda x: cond_fun(x), lambda x: body_fun(x), init_val)

    [... skipping hidden 8 frame]

File ~/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api.py:121, in _nan_check_posthook(fun, args, kwargs, output)
    119 f = fun._fun
    120 if getattr(f, '_apply_primitive', False):
--> 121   raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
    122 # compiled_fun can only raise in this case
    123 dispatch.maybe_recursive_nan_check(e, f, args, kwargs)

FloatingPointError: invalid value (nan) encountered in broadcast_in_dim

cc @aidancrilly, who first noticed this issue.

jpbrodrick89 · May 08 '25 14:05

With which versions of JAX, equinox and diffrax do you get this error?

johannahaffner · May 08 '25 15:05

jax version: 0.6.0
equinox version: 0.12.1
diffrax version: 0.7.0

jpbrodrick89 · May 08 '25 16:05

I can reproduce it with this JAX-only MWE:

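import os

# The flags are read when JAX is imported, so set them first.
os.environ["JAX_DEBUG_NANS"] = "1"
os.environ["JAX_DISABLE_JIT"] = "1"

import jax
import jax.core
import jax.numpy as jnp


t = jnp.array([0.5, 1.5])
xp = jnp.array([0, 1, 2])
fp = jnp.array([0.0, 1.0, 0.0])
# Trace jnp.interp into a jaxpr, then evaluate that jaxpr separately.
jaxpr = jax.make_jaxpr(lambda t: jnp.interp(t, xp, fp))(t)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, t)  # raises FloatingPointError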

The culprit appears to be this line:

https://github.com/jax-ml/jax/blob/8c987bfb0afd36710038a96430ae95f473503e08/jax/_src/lax/lax.py#L7923

Although that line is wrapped in debug_nans(False) during the initial tracing into a jaxpr, the context is not recorded in the jaxpr itself, so it is no longer active during the subsequent evaluation.
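The mechanism can be reproduced with a toy function. This is a hypothetical sketch, not the actual lax.py code. The function f produces a NaN and masks it while debug_nans(False) is active, which works when f is called eagerly. That guard is a Python-level context, though, so it does not survive into the jaxpr, and evaluating the jaxpr with jax.core.eval_jaxpr (as in the JAX-only MWE above) trips the NaN check.

```python
import jax
import jax.core
import jax.numpy as jnp
from jax import lax


def f(x):
    # Hypothetical toy: deliberately produce NaN, then mask it, with NaN
    # checking switched off for just these ops.
    with jax.debug_nans(False):
        y = lax.log(x)  # log of a negative number is NaN
        y = lax.select(lax.ne(y, y), lax.full_like(y, 0.0), y)  # NaN -> 0
    return y


x = jnp.array([-1.0, 2.0])

# Tracing records only the primitives; the debug_nans(False) context is not
# part of the resulting jaxpr.
closed = jax.make_jaxpr(f)(x)

with jax.debug_nans(True), jax.disable_jit():
    # Eager call: the guard is active while the ops execute, so no error.
    direct = f(x)

    # Jaxpr evaluation: each primitive is executed and NaN-checked again,
    # this time without the guard, so `log` raises.
    try:
        jax.core.eval_jaxpr(closed.jaxpr, closed.consts, x)
        jaxpr_raised = False
    except FloatingPointError:
        jaxpr_raised = True

print(direct, jaxpr_raised)
```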

I'm not immediately sure what would be a reasonable workaround from your perspective, though.

patrick-kidger · May 08 '25 16:05

Thanks. Do you know why it depends on the choice of adjoint?

jpbrodrick89 · May 08 '25 17:05

I think it's simply that one of the adjoint methods includes code of the above sort (tracing a function into a jaxpr and evaluating that jaxpr later) whilst the others do not.
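The traceback shows the checkpointed loop passing body_fun through equinox's filter_closure_convert before running it. The sketch below assumes that JAX's own jax.closure_convert behaves analogously: it traces the function to a jaxpr once and evaluates that jaxpr on each call. Under that assumption, it shows why only the closure-converted path trips the check. The toy f, which produces and masks a NaN under a local debug_nans(False), is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Hypothetical toy: a NaN produced and masked under a local guard.
    with jax.debug_nans(False):
        y = lax.log(x)
        y = lax.select(lax.ne(y, y), lax.full_like(y, 0.0), y)
    return y


x = jnp.array([-1.0, 2.0])

# Trace f once; `converted` re-evaluates the stored jaxpr on each call.
converted, hoisted_consts = jax.closure_convert(f, x)

with jax.debug_nans(True), jax.disable_jit():
    # Analogous to a loop that calls the body function directly: fine.
    plain_out = f(x)
    # Analogous to a closure-converted body: the guard is gone, so it raises.
    try:
        converted(x, *hoisted_consts)
        converted_raised = False
    except FloatingPointError:
        converted_raised = True

print(plain_out, converted_raised)
```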

(I'm tagging this as a question rather than a bug, as sadly I don't think anything is going wrong in Diffrax here, nor that there's anything we can do to work around it in Diffrax.)

patrick-kidger · May 08 '25 19:05

Thanks Patrick. I did some digging and noticed there is some commentary about this in jax._src.core.jaxpr_as_fun. I'm guessing there's no clever way we could use jaxpr_as_fun with RecursiveCheckpointAdjoint? It sounds like there may be ambitions to attach contexts to jaxprs one day, which should solve the problem. Do you think it's worth raising this as an issue on the JAX side? My gut feeling is that the fix is not at all trivial there either.

For posterity, if anyone encounters something similar in the future, my current approaches to useful debugging in these situations are to:

  1. use DirectAdjoint, reducing max_steps considerably to avoid memory issues, or
  2. manually run a single step in isolation without diffrax. This enabled me to debug a nasty reverse-pass NaN I wasn't able to pin down otherwise. (Unfortunately, in my limited experience, debugging NaNs without disabling JIT typically gives very opaque error messages.)
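Here is a minimal sketch of the second approach for the MWE above. It assumes an explicit Euler update y1 = y0 + dt * f(t0, y0), which is what diffrax.Euler computes for an ODETerm. It also scopes the debug flags with context managers. The names euler_step and loss are hypothetical helpers, not diffrax API.

```python
import jax
import jax.numpy as jnp

xp = jnp.array([0, 1, 2])
fp = jnp.array([0.0, 1.0, 0.0])


def vector_field(t, u, args):
    return jnp.interp(u, xp, fp)


def euler_step(y0, t0, dt):
    # Hypothetical helper: one explicit Euler step, y1 = y0 + dt * f(t0, y0).
    return y0 + dt * vector_field(t0, y0, None)


def loss(y0):
    # Hypothetical scalar objective so that the reverse pass is exercised too.
    return jnp.sum(euler_step(y0, 0.0, 0.01))


y0 = jnp.array([0.5, 1.5])

# Run the forward and reverse passes eagerly with NaN checking on, outside
# diffrax, so a NaN is reported at the primitive that produced it.
with jax.debug_nans(True), jax.disable_jit():
    y1 = euler_step(y0, 0.0, 0.01)
    grad_y0 = jax.grad(loss)(y0)

print(y1, grad_y0)
```

The gradient is 1 + dt * slope per component, so it should come out close to [1.01, 0.99].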

jpbrodrick89 · May 10 '25 02:05

From what I've seen I suspect this is probably a wontfix from the JAX side, although I don't wish to speak on behalf of their team.

patrick-kidger · May 11 '25 22:05