diffrax icon indicating copy to clipboard operation
diffrax copied to clipboard

Cannot reproduce the stiff ODE example

Open Rumoa opened this issue 2 months ago • 1 comment

Hello, I am trying to use stiff ODE solvers such as Kvaerno5. When running the stiff ODE example code from the docs:

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp


class Robertson(eqx.Module):
    """Right-hand side of the Robertson stiff ODE test problem.

    The three rate constants are stored as module fields, so an instance can
    be wrapped directly in ``diffrax.ODETerm``.
    """

    k1: float
    k2: float
    k3: float

    def __call__(self, t, y, args):
        """Return dy/dt for state ``y``; ``t`` and ``args`` are not used."""
        a, b, c = y[0], y[1], y[2]
        # Sub-terms shared between several components of the derivative.
        bc_term = self.k3 * b * c
        bb_term = self.k2 * b ** 2
        da = -self.k1 * a + bc_term
        db = self.k1 * a - bb_term - bc_term
        dc = bb_term
        return jnp.stack([da, db, dc])


@jax.jit
def main(k1, k2, k3):
    """Integrate the Robertson problem on [0, 100] with Kvaerno5.

    Starts from y0 = [1, 0, 0] with an initial step of 2e-4 and an adaptive
    PID step-size controller (rtol = atol = 1e-8). The solution is saved at
    t = 0 and at every power of ten from 1e-4 to 1e2. Returns the diffrax
    solution object.
    """
    save_times = jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2])
    controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
    return diffrax.diffeqsolve(
        diffrax.ODETerm(Robertson(k1, k2, k3)),
        diffrax.Kvaerno5(),
        0.0,
        100.0,
        0.0002,
        jnp.array([1.0, 0.0, 0.0]),
        saveat=diffrax.SaveAt(ts=save_times),
        stepsize_controller=controller,
    )


# Warm-up call: the first invocation traces and compiles `main`. Block on its
# result so that its asynchronously dispatched execution cannot overlap with
# the timed run below.
jax.block_until_ready(main(0.04, 3e7, 1e4))

# JAX dispatches work asynchronously, so wait until the solution has actually
# been computed before reading the clock again; otherwise only dispatch time
# would be measured. `perf_counter` is the clock intended for timing intervals.
start = time.perf_counter()
sol = jax.block_until_ready(main(0.04, 3e7, 1e4))
end = time.perf_counter()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[91], line 45
     32     sol = diffrax.diffeqsolve(
     33         terms,
     34         solver,
   (...)
     40         stepsize_controller=stepsize_controller,
     41     )
     42     return sol
---> 45 main(0.04, 3e7, 1e4)
     47 start = time.time()
     48 sol = main(0.04, 3e7, 1e4)

    [... skipping hidden 13 frame]

Cell In[91], line 32, in main(k1, k2, k3)
     30 saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
     31 stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
---> 32 sol = diffrax.diffeqsolve(
     33     terms,
     34     solver,
     35     t0,
     36     t1,
     37     dt0,
     38     y0,
     39     saveat=saveat,
     40     stepsize_controller=stepsize_controller,
     41 )
     42 return sol

    [... skipping hidden 18 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1389 init_state = State(
   1390     y=y0,
   1391     tprev=tprev,
   (...)
   1409     event_mask=event_mask,
   1410 )
   1412 #
   1413 # Main loop
   1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
   1417     args=args,
   1418     terms=terms,
   1419     solver=solver,
   1420     stepsize_controller=stepsize_controller,
   1421     event=event,
   1422     saveat=saveat,
   1423     t0=t0,
   1424     t1=t1,
   1425     dt0=dt0,
   1426     max_steps=max_steps,
   1427     init_state=init_state,
   1428     throw=throw,
   1429     passed_solver_state=passed_solver_state,
   1430     passed_controller_state=passed_controller_state,
   1431     progress_meter=progress_meter,
   1432 )
   1434 #
   1435 # Finish up
   1436 #
   1438 progress_meter.close(final_state.progress_meter_state)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
    295     outer_while_loop = ft.partial(
    296         _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
    297     )
    298     msg = None
--> 299 final_state = self._loop(
    300     terms=terms,
    301     saveat=saveat,
    302     init_state=init_state,
    303     max_steps=max_steps,
    304     inner_while_loop=inner_while_loop,
    305     outer_while_loop=outer_while_loop,
    306     **kwargs,
    307 )
    308 if msg is not None:
    309     final_state = eqxi.nondifferentiable_backward(
    310         final_state, msg=msg, symbolic=True
    311     )

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:619, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
    617 static_made_jump = init_state.made_jump
    618 static_result = init_state.result
--> 619 _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
    620 if traced_jump:
    621     static_made_jump = None

    [... skipping hidden 16 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_integrate.py:349, in loop.<locals>.body_fun_aux(state)
    342 state = _handle_static(state)
    344 #
    345 # Actually do some differential equation solving! Make numerical steps, adapt
    346 # step sizes, all that jazz.
    347 #
--> 349 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    350     terms,
    351     state.tprev,
    352     state.tnext,
    353     state.y,
    354     args,
    355     state.solver_state,
    356     state.made_jump,
    357 )
    359 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
    360 # we get a negative value for y, and then get a NaN vector field. (And then
    361 # everything breaks.) See #143.
    362 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1149, in AbstractRungeKutta.step(***failed resolving arguments***)
   1142 const_result = const_result_sentinel = object()
   1143 # Needs to be an `eqxi.while_loop` as:
   1144 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
   1145 #     more stage on the first step.
   1146 # (b) to work around a limitation of JAX's autodiff being unable to express
   1147 #     "triangular computations" (every stage depends on all previous stages)
   1148 #     without spurious copies.
-> 1149 final_val = eqxi.while_loop(
   1150     cond_stage,
   1151     rk_stage,
   1152     init_val,
   1153     max_steps=num_stages,
   1154     buffers=buffers,
   1155     kind="checkpointed" if self.scan_kind is None else self.scan_kind,
   1156     checkpoints=num_stages,
   1157     base=num_stages,
   1158 )
   1159 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
   1160 assert const_result is not const_result_sentinel

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
    105 elif kind == "checkpointed":
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":
    116     del kind, checkpoints

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/checkpointed.py:247, in checkpointed_while_loop(***failed resolving arguments***)
    245 cond_fun_ = filter_closure_convert(cond_fun_, init_val_)
    246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248 vjp_arg = (init_val_, body_fun_)
    249 final_val_ = _checkpointed_while_loop(
    250     vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251 )

    [... skipping hidden 17 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/common.py:474, in common_rewrite.<locals>.new_body_fun(val)
    472 step, pred, _, val = val
    473 buffer_val = _wrap_buffers(val, pred, tag)
--> 474 buffer_val2 = body_fun(buffer_val)
    475 # Needed to work with `disable_jit`, as then we lose the automatic
    476 # ArrayLike->Array cast provided by JAX's while loops.
    477 # The input `val` is already cast to Array below, so this matches that.
    478 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:984, in AbstractRungeKutta.step.<locals>.rk_stage(val)
    982 if eval_fs:
    983     jac_f = eqxi.nondifferentiable(jac_f, name="jac_f")
--> 984     nonlinear_sol = optx.root_find(
    985         _implicit_relation_f,
    986         self.root_finder,  # pyright: ignore
    987         f_pred,
    988         f_implicit_args,
    989         options=dict(init_state=jac_f),
    990         throw=False,
    991         max_steps=self.root_find_max_steps,  # pyright: ignore
    992     )
    993     implicit_fi = nonlinear_sol.value
    994     implicit_ki = _unused

    [... skipping hidden 18 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_root_find.py:220, in root_find(fn, solver, y0, args, options, has_aux, max_steps, adjoint, throw, tags)
    218 if options is None:
    219     options = {}
--> 220 return iterative_solve(
    221     fn,
    222     solver,
    223     y0,
    224     args,
    225     options,
    226     max_steps=max_steps,
    227     adjoint=adjoint,
    228     throw=throw,
    229     tags=tags,
    230     f_struct=f_struct,
    231     aux_struct=aux_struct,
    232     rewrite_fn=_rewrite_fn,
    233 )

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:346, in iterative_solve(fn, solver, y0, args, options, max_steps, adjoint, throw, tags, f_struct, aux_struct, rewrite_fn)
    334 aux_struct = jtu.tree_map(eqxi.Static, aux_struct)
    335 inputs = fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags
    336 (
    337     out,
    338     (
    339         num_steps,
    340         result,
    341         dynamic_final_state,
    342         static_state,
    343         aux,
    344         stats,
    345     ),
--> 346 ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
    347 final_state = eqx.combine(dynamic_final_state, unwrap_jaxpr(static_state.value))
    348 stats = {"num_steps": num_steps, "max_steps": max_steps, **stats}

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_adjoint.py:134, in ImplicitAdjoint.apply(self, primal_fn, rewrite_fn, inputs, tags)
    132 def apply(self, primal_fn, rewrite_fn, inputs, tags):
    133     inputs = inputs + (ft.partial(eqxi.while_loop, kind="lax"),)
--> 134     return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_ad.py:60, in implicit_jvp(fn_primal, fn_rewrite, inputs, tags, linear_solver)
     58 assert _is_global_function(fn_primal)
     59 assert _is_global_function(fn_rewrite)
---> 60 root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
     61 return root, jtu.tree_map(eqxi.nondifferentiable_backward, residual)

    [... skipping hidden 14 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_ad.py:67, in _implicit_impl(***failed resolving arguments***)
     64 @eqx.filter_custom_jvp
     65 def _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver):
     66     del fn_rewrite, tags, linear_solver
---> 67     return jtu.tree_map(jnp.asarray, fn_primal(inputs))

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:242, in _iterate(***failed resolving arguments***)
    239     _, _, state, _ = carry
    240     return solver.buffers(state)
--> 242 final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
    243 final_y, num_steps, dynamic_final_state, final_aux = final_carry
    244 final_state = eqx.combine(static_state, dynamic_final_state)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:103, in while_loop(***failed resolving arguments***)
     99     cond_fun_, body_fun_, init_val_, _ = common_rewrite(
    100         cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
    101     )
    102     del cond_fun, body_fun, init_val
--> 103     _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
    104     return final_val
    105 elif kind == "checkpointed":

    [... skipping hidden 10 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_loop/common.py:474, in common_rewrite.<locals>.new_body_fun(val)
    472 step, pred, _, val = val
    473 buffer_val = _wrap_buffers(val, pred, tag)
--> 474 buffer_val2 = body_fun(buffer_val)
    475 # Needed to work with `disable_jit`, as then we lose the automatic
    476 # ArrayLike->Array cast provided by JAX's while loops.
    477 # The input `val` is already cast to Array below, so this matches that.
    478 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/optimistix/_iterate.py:232, in _iterate.<locals>.body_fun(carry)
    230 y, num_steps, dynamic_state, _ = carry
    231 state = eqx.combine(static_state, dynamic_state)
--> 232 new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
    233 new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array)
    235 assert eqx.tree_equal(static_state, new_static_state) is True

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/diffrax/_root_finder/_verychord.py:127, in VeryChord.step(***failed resolving arguments***)
    125 jac, linear_state = state.linear_state
    126 linear_state = lax.stop_gradient(linear_state)
--> 127 sol = lx.linear_solve(
    128     jac, fx, self.linear_solver, state=linear_state, throw=False
    129 )
    130 diff = sol.value
    131 new_y = (y**ω - diff**ω).ω

    [... skipping hidden 18 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/lineax/_solve.py:810, in linear_solve(operator, vector, solver, options, state, throw)
    804 options = eqxi.nondifferentiable(
    805     options, name="`lineax.linear_solve(..., options=...)`"
    806 )
    807 solver = eqxi.nondifferentiable(
    808     solver, name="`lineax.linear_solve(..., solver=...)`"
    809 )
--> 810 solution, result, stats = eqxi.filter_primitive_bind(
    811     linear_solve_p, operator, state, vector, options, solver, throw
    812 )
    813 # TODO: prevent forward-mode autodiff through stats
    814 stats = eqxi.nondifferentiable_backward(stats)

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/equinox/internal/_primitive.py:273, in filter_primitive_bind(prim, *args)
    271 static = tuple(_missing_dynamic if is_array(x) else x for x in flat)
    272 flatten = Flatten()
--> 273 flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
    274 treedef_out, static_out = flatten.get()
    275 return combine(jtu.tree_unflatten(treedef_out, flat_out), static_out)

    [... skipping hidden 5 frame]

File ~/miniconda3/envs/jax-gr/lib/python3.12/site-packages/jax/_src/util.py:465, in multi_weakref_lru_cache.<locals>.wrapper(*orig_args, **orig_kwargs)
    462   return cached_call(acc_weakrefs[0],
    463                      *args, **kwargs)
    464 else:
--> 465   value_to_weakref = {v: weakref.ref(v, remove_weakref)
    466                       for v in set(acc_weakrefs)}
    467   key = MultiWeakRefCacheKey(weakrefs=tuple(value_to_weakref[v]
    468                                             for v in acc_weakrefs))
    469   return cached_call(key, *args, **kwargs)

TypeError: cannot create weak reference to 'Flatten' object

I am using JAX 0.7.1 and Diffrax 0.7.0.

Rumoa avatar Oct 29 '25 16:10 Rumoa

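I think upgrading to Equinox >= 0.13.1 should fix this! The traceback ends in Equinox's `filter_primitive_bind`, whose `Flatten` object cannot be weakly referenced by the JAX 0.7.1 cache, so this looks like an Equinox/JAX version incompatibility rather than a problem with the example itself.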

johannahaffner avatar Oct 29 '25 21:10 johannahaffner