
Improve error messages

Open • jpbrodrick89 opened this issue 4 months ago • 1 comment

diffeqsolve traces the vf of each ODETerm before running, to check that shapes are compatible. However, when a user has a bug in their vf, traversing the stack trace is cumbersome: the real problem typically sits about 30-50% of the way through it. The last error shown says only that the terms are not compatible, which is not helpful, and scrolling back through the trace is tedious when y0 is a complicated pytree and vf a complicated eqx.Module. When an error occurs during this check, is it possible to exit earlier, or to truncate the unnecessary diffeqsolve stack trace?

MWE:

import jax.numpy as jnp
import diffrax

def vf(t, y, args):
    return x  # deliberate bug: x is not defined

diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0.0, 1.0, 0.1, jnp.zeros(1))
This gives the following stack trace:
Traceback (most recent call last):
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 165, in _check
    vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
                              ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api.py", line 3014, in eval_shape
    return jit(fun).eval_shape(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 352, in jit_eval_shape
    p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 686, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 710, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ~~~~~~~~~~~~~~~~~~^
        fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 606, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
                                              ~~~~~~~~~~~~~~~~~~^
        flat_fun, in_type, attr_token, IgnoreKey(ji.inline))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 471, in memoized_fun
    ans = call(fun, *args)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/pjit.py", line 1414, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(fun, in_type)
                                                     ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2292, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api_util.py", line 288, in _argnums_partial
    return _fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_eval_shape.py", line 33, in _fn
    _out = _fun(*_args, **_kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_module.py", line 1060, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_term.py", line 194, in vf
    out = self.vector_field(t, y, args)
  File "<python-input-3>", line 2, in vf
    return x
           ^
NameError: name 'x' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 195, in _assert_term_compatible
    jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/tree_util.py", line 362, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/jax/_src/tree_util.py", line 362, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ~^^^^^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 167, in _check
    raise ValueError(f"Error while tracing {term}.vf: " + str(e))
ValueError: Error while tracing ODETerm(vector_field=<function vf>).vf: name 'x' is not defined

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<python-input-3>", line 3, in <module>
    diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0.0, 1.0, 0.1, jnp.zeros(1))
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
  File "/Users/jonathanbrodrick/.virtualenvs/ergodic/lib/python3.13/site-packages/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ~~~~~~~~~~~~~~~~~~~^
        dynamic_donate, dynamic_nodonate, static
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 1117, in diffeqsolve
    _assert_term_compatible(
    ~~~~~~~~~~~~~~~~~~~~~~~^
        t0,
        ^^^
    ...<4 lines>...
        solver.term_compatible_contr_kwargs,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/jonathanbrodrick/pasteurcodes/diffrax/diffrax/_integrate.py", line 200, in _assert_term_compatible
    raise ValueError(
    ...<3 lines>...
    ) from e
ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<function vf>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
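
Until the trace is shortened at the source, the root-cause error mentioned in the final message can be pulled out of the exception chain programmatically. The following is a minimal sketch in plain Python, not diffrax code: `root_cause`, `_check`, and `solve` are hypothetical names, and `_check` and `solve` only imitate the two levels of wrapping visible in the trace above. It assumes the wrapping errors are linked through Python's standard `__cause__` and `__context__` attributes.

```python
import traceback


def root_cause(exc):
    """Follow __cause__ / __context__ links back to the first exception raised."""
    seen = set()
    while id(exc) not in seen:  # guard against cycles in the chain
        seen.add(id(exc))
        earlier = exc.__cause__ if exc.__cause__ is not None else exc.__context__
        if earlier is None:
            break
        exc = earlier
    return exc


def vf(t, y, args):
    return x  # deliberate bug, as in the MWE: x is not defined


def _check():
    # Hypothetical stand-in: re-raises inside `except`, so the NameError
    # becomes the implicit __context__ of the new ValueError.
    try:
        vf(0.0, [0.0], None)
    except Exception as e:
        raise ValueError("Error while tracing vf: " + str(e))


def solve():
    # Hypothetical stand-in: wraps again with `from e`, setting __cause__.
    try:
        _check()
    except Exception as e:
        raise ValueError("Terms are not compatible with solver!") from e


try:
    solve()
except ValueError as err:
    cause = root_cause(err)
    # Print only the innermost exception and its own frames, without the chain.
    traceback.print_exception(type(cause), cause, cause.__traceback__, chain=False)
```

With diffrax, the same helper could presumably be applied to the exception caught around `diffeqsolve`, although that is untested here.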

It would be preferable for the stack trace to end with NameError: name 'x' is not defined, which is the real issue.
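
One way to get that behaviour is sketched below in plain Python, under stated assumptions: the compatibility check does not catch exceptions raised inside the user's function, so they propagate unchanged. The names `check_wrapping`, `check_passthrough`, and `last_line` are hypothetical, and this is not necessarily how diffrax handles it.

```python
import traceback


def vf(t, y, args):
    return x  # deliberate bug, as in the MWE: x is not defined


def check_wrapping(fn):
    # Imitates the pattern in the trace above: catch everything and raise a
    # new, more generic error, which then becomes the last line of the trace.
    try:
        fn(0.0, [0.0], None)
    except Exception as e:
        raise ValueError("Terms are not compatible with solver!") from e


def check_passthrough(fn):
    # Alternative: do not catch errors from the user's function at all, so
    # the user's own exception is the one that reaches the top level.
    fn(0.0, [0.0], None)


def last_line(check):
    """Return the final line of the formatted traceback produced by `check`."""
    try:
        check(vf)
    except Exception as e:
        lines = traceback.format_exception(type(e), e, e.__traceback__)
        return lines[-1].strip()


print(last_line(check_wrapping))     # final line reports the wrapping ValueError
print(last_line(check_passthrough))  # final line reports the user's NameError
```

In the pass-through variant the frames of the calling library still appear above the user's frame, but the trace ends at the user's own error rather than at a generic compatibility message.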

jpbrodrick89, Aug 26 '25 13:08

Thanks for the report! This makes sense to me. I've tweaked things in #682.

patrick-kidger, Aug 31 '25 11:08