diffrax

Could you give an example of interactively stepping through an implicit solver?

Open chaoming0625 opened this issue 7 months ago • 5 comments

In the example at https://docs.kidger.site/diffrax/usage/manual-stepping/, how do I replace the explicit solver with an implicit one? Here is my attempt, but I get the following error:


import jax.numpy as jnp
import diffrax
from diffrax import ODETerm

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = diffrax.ImplicitEuler()

t0 = 0
dt0 = 0.05
t1 = 1
y0 = jnp.array(1.0)
args = None

tprev = t0
tnext = t0 + dt0
y = y0
state = solver.init(term, tprev, tnext, y0, args)

while tprev < t1:
    y, _, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)
    print(f"At time {tnext} obtained value {y}")
    tprev = tnext
    tnext = min(tprev + dt0, t1)


The error message:


    raise RuntimeError("Type of `other` not understood.")
RuntimeError: Type of `other` not understood.

chaoming0625, May 21 '25 14:05

You need to specify the tolerances of the solver's root finder:

solver = diffrax.ImplicitEuler(root_finder=diffrax.VeryChord(rtol=1e-8, atol=1e-8))

What's going on here is that, by default, diffeqsolve sets the solver's root-finder tolerances to those used by the adaptive step size controller. As you're skipping diffeqsolve (and the adaptive step size controller, which is normally needed for an implicit solver), you need to set the tolerances manually... and you also skip the nice error message describing what's going on!

patrick-kidger, May 21 '25 20:05

Thanks. This works.

BTW, I'm also curious how to use Kvaerno3. With it, I get this error:


ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided:

Called with: ((f32[], f32[], None), {})

Closure-converted with: ((None, f32[], None), {})

chaoming0625, May 22 '25 13:05

Wrapping the inputs in floating-point arrays fixes this one:

t0 = jnp.array(0.)
dt0 = jnp.array(0.05)
t1 = jnp.array(1.)
y0 = jnp.array(1.0)

You're aptly demonstrating that our step-by-step API needs better error messages and better handling of edge cases! Since we pretty much always call it via diffeqsolve, most of the nice error messages live up there instead. As such, I'm going to mark this as a feature.

patrick-kidger, May 22 '25 18:05

This solves the issue. Thanks.

chaoming0625, May 23 '25 11:05