sol.ts contains wrong values in some cases
When saveat= and jump_ts= are used together, sol.ts can contain wrong (infinite) values. Here is a reproducible example:
import jax.numpy as jnp
import jax
import diffrax as dfx

Q = jnp.array([[-1., 1.], [0., 0.]])
y0 = jnp.array([0.5, 0.5])

def A(t, y, _):
    return Q @ y

solver = dfx.Kvaerno3()

def f(t1):
    saveat = dfx.SaveAt(t1=True, ts=[t1])
    ssc = dfx.PIDController(atol=1e-6, rtol=1e-6, jump_ts=jnp.array([0., t1]))
    res = dfx.diffeqsolve(
        dfx.ODETerm(A),
        solver=solver,
        y0=y0,
        t0=0.,
        t1=t1,
        dt0=0.01,
        stepsize_controller=ssc,
        saveat=saveat
    )
    return res.ts

for t1 in [1e0, 1e3]:
    print(t1, f(t1))

# 1.0 [1. 1.]
# 1000.0 [1000. inf]
The expected output is [t1, t1] regardless of what t1 is. However, for certain values, the saved timepoints are erroneously [t1, inf]. The behavior seems to depend on the magnitude of t1, so I suspect maybe a jnp.nextafter-ish type of bug.
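To illustrate why the magnitude of t1 could matter (this is not a diagnosis of the bug), here is a minimal standard-library sketch of how the gap to the next representable float32 grows with the value. It assumes float32, which is JAX's default precision, and emulates by hand what a nextafter-style call would return. The helper names next_float32 and float32_gap are made up for this illustration and are not part of JAX or Diffrax.

```python
import struct


def next_float32(x: float) -> float:
    """Return the smallest float32 strictly greater than x.

    Hypothetical helper for illustration only. Valid for positive, finite x
    that is exactly representable in float32.
    """
    # Reinterpret the float32 bit pattern as an unsigned integer,
    # add one, and reinterpret the result as a float32 again.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (nxt,) = struct.unpack("<f", struct.pack("<I", bits + 1))
    return nxt


def float32_gap(x: float) -> float:
    """Distance from x to the next representable float32 above it."""
    return next_float32(x) - x


for t1 in [1e0, 1e3]:
    print(t1, float32_gap(t1))

# 1.0 1.1920928955078125e-07   (2**-23)
# 1000.0 6.103515625e-05       (2**-14)
```

The gap at 1000 is 512 times larger than at 1, so any logic that nudges a time by one representable step behaves differently at the two magnitudes. Whether that is what actually goes wrong here is only a guess.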
This occurs on HEAD as well as the latest release.
Might be related to #607.
Thanks for the report! It turns out that the upcoming #660 will solve this - we're just about to land that.
You should be able to install Diffrax from #660 as a workaround until the next release, and I've also just opened #665 with a test for the issue you've found here.