
sol.ts contains wrong values in some cases

Open terhorst opened this issue 6 months ago • 1 comment

When saveat= and jump_ts= are used together, sol.ts can contain wrong (infinite) values. Here is a reproducible example:

import jax.numpy as jnp
import jax
import diffrax as dfx

Q = jnp.array([[-1., 1.], [0., 0.]])
y0 = jnp.array([0.5, 0.5])

def A(t, y, _):
    return Q @ y

solver = dfx.Kvaerno3()

def f(t1):
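    # Request t1 twice: once via t1=True and once via ts=[t1].
    saveat = dfx.SaveAt(t1=True, ts=[t1])
    # Declare both endpoints, 0 and t1, as jump times for the step size controller.
    ssc = dfx.PIDController(atol=1e-6, rtol=1e-6, jump_ts=jnp.array([0., t1]))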

    res = dfx.diffeqsolve(
        dfx.ODETerm(A),
        solver=solver,
        y0=y0,
        t0=0.,
        t1=t1,
        dt0=0.01,
        stepsize_controller=ssc,
        saveat=saveat
    )
    return res.ts

for t1 in [1e0, 1e3]:
    print(t1, f(t1))

# Output:
# 1.0 [1. 1.]
# 1000.0 [1000.   inf]

The expected output is [t1, t1] regardless of the value of t1. However, for some values the saved time points are erroneously [t1, inf]. The behavior seems to depend on the magnitude of t1, so I suspect a floating-point rounding bug, perhaps one involving jnp.nextafter.
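
To illustrate why the magnitude of t1 could matter for a nextafter-style adjustment, here is a minimal sketch of how the spacing between adjacent float32 values grows with magnitude. It assumes the default 32-bit JAX setup, uses only the Python standard library rather than jnp.nextafter, and the helpers f32_next_up and f32_gap are hypothetical names for this illustration. It shows the general floating-point effect only, not the confirmed cause of the bug.

```python
import struct

def f32_next_up(x):
    """Smallest float32 strictly greater than x (hypothetical helper).

    Only valid for positive, finite x that is exactly representable in float32.
    """
    # Reinterpret the float32 bit pattern as an unsigned integer and add one.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + 1))[0]

def f32_gap(x):
    """Distance from x to the next representable float32 above it."""
    return f32_next_up(x) - x

for t1 in [1e0, 1e3]:
    print(t1, f32_gap(t1))
```

The gap is 2**-23 at 1.0 and 2**-14 at 1000.0, so a one-ULP nudge around t1 is 512 times larger in the failing case than in the working one.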

This occurs on HEAD as well as the latest release.

Might be related to #607.

terhorst, Jul 07 '25 19:07

Thanks for the report! It turns out that the upcoming #660 will solve this - we're just about to land that.

You should be able to install Diffrax from #660 as a workaround until the next release, and I've also just opened #665 with a test for the issue you've found here.
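
Until a release with the fix is available, a defensive check on the returned times is one option. The following is only a sketch: check_saved_ts is a hypothetical helper, not part of Diffrax and not the test from #665, and it assumes the saved times have been converted to plain Python floats (for example with res.ts.tolist()).

```python
import math

def check_saved_ts(ts, t1, rel_tol=1e-6):
    """Return True if every saved time is finite and close to t1.

    Hypothetical helper: `ts` is an iterable of Python floats.
    """
    return all(
        math.isfinite(t) and math.isclose(t, t1, rel_tol=rel_tol)
        for t in ts
    )

# The two outputs reported in this issue:
print(check_saved_ts([1.0, 1.0], 1.0))
print(check_saved_ts([1000.0, float("inf")], 1000.0))
```

The first call passes and the second fails, matching the good and bad outputs reported above.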

patrick-kidger, Jul 13 '25 10:07