Issue with small time steps
First off, thanks for this (and your other) libraries! Very useful for scientific programming with JAX.
I'm having an odd issue with the time-stepping. If the time step is too small, it looks like some kind of numerical error creeps in. Here's a simple MWE, with the ODE being exponential decay:
from diffrax import diffeqsolve, Dopri5, Euler, ODETerm, SaveAt, ConstantStepSize
import jax
jax.config.update("jax_enable_x64", True)
vector_field = lambda t, y, args: -y / args[0]
term = ODETerm(vector_field)
solver = Euler()
saveat = SaveAt(t0=True, t1=True)
stepsize_controller = ConstantStepSize()
tau = 1e-11
dt0 = tau / 1000
sol = diffeqsolve(term, solver, t0=0, t1=tau, dt0=dt0, y0=1.0, saveat=saveat,
args=(tau,),
stepsize_controller=stepsize_controller, max_steps=10000)
print(sol.ts[0:10])
print(sol.ys[0:10])
It's set up so that, whatever the decay time tau, the solver takes 1000 steps over one decay time, so the solution at t1 should always be 1/e.
For tau above about 1e-9, you get the right answer. For tau < 1e-9, you start getting some pretty bad numerical error. For instance, tau = 1e-10 gives an answer of about 1e-4.
If I switch to Bosh3 the issue is smaller, but there is still a tau dependence in the error.
The same thing run through scipy.integrate.solve_ivp shows no tau dependence in the solution error (though of course, diffrax is about two orders of magnitude faster!). I also jitted the function passed to scipy.integrate.solve_ivp and it made no difference, so I guess the issue is somewhere in diffrax?
Can I reparametrize the ODE I'm working with to avoid small numbers? Yes, but it does make me worry about numerical stability. Any advice?
Hey there! Thanks for the report. So it's interesting to also print out sol.stats["num_steps"] as tau varies: for large values we get 1000, but for small values we get only 2!
What's going on here is that if a step is very close to t1, then we actually clip it directly to t1:
https://github.com/patrick-kidger/diffrax/blob/5e5fed6c6abe116de35af835ee66febbbabcadaf/diffrax/_integrate.py#L266
And when the timescales involved are very small, then this triggers almost instantly.
This is needed in a few circumstances: in particular, a lot of dense interpolation involves some kind of 'divide by the length of the step', which becomes numerically unstable for short steps. That can arise when performing adaptive step sizes, but an even more common case is when we keep incrementing our steps by dt0 and, due to floating point error, find ourselves a few ULPs short of t1 on what would otherwise be the 'final' step.
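As a self-contained illustration (plain Python floats, not diffrax itself) of that 'a few ULPs short of t1' effect: repeatedly adding a fixed dt0 accumulates rounding error, so the last increment does not land exactly on t1.

```python
# Repeatedly adding dt0 = 0.1 (not exactly representable in binary)
# accumulates rounding error, so after 10 steps we land one ULP below 1.0.
t0, t1, n = 0.0, 1.0, 10
dt0 = (t1 - t0) / n
t = t0
for _ in range(n):
    t += dt0
print(t == t1)  # False
print(t)        # 0.9999999999999999, the float just below 1.0
```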
FWIW, at least in that latter case, I can believe that we could find an alternate fix by setting constant steps using something like jnp.linspace(t0, t1, num_steps + 1)[step_index] (with num_steps = (t1 - t0) / dt0) instead. I have had this scenario occur for adaptive step-sizing too though, so I think ideally there'd be some more general fix.
WDYT? I'd welcome any thoughts on a fix here.
Makes sense! I think using linspace for the constant step size case would be fine. I see where the issue is for adaptive though. Here's an idea - just change the logic for clipping to
clip = (t1 - tnext) / t1 < const * mach_eps
where mach_eps is the machine epsilon for singles/doubles. We can use a const > 1, like 2, to give more padding.
In other words, just use a smaller threshold? FWIW this has issues when you use fairly large values, as the epsilon starts to become too small.
Ultimately I think what's needed here is to have a tolerance of e.g. '10 ULPs' rather than a numerical value. We could maybe test for that by casting the float's byte representation into an integer, although that seems pretty hacky...
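For what it's worth, a pure-Python sketch of that bit-cast trick (float_to_ordered_int is a hypothetical helper name, valid for positive finite doubles only): consecutive floats map to consecutive integers, so integer differences count ULPs.

```python
import math
import struct

def float_to_ordered_int(x):
    # Reinterpret a positive double's bits as a signed 64-bit integer.
    # Adjacent positive doubles map to adjacent integers, so differences
    # count ULPs. Negative numbers would need extra handling.
    return struct.unpack("<q", struct.pack("<d", x))[0]

t1 = 1e-11
one_ulp_up = math.nextafter(t1, math.inf)
print(float_to_ordered_int(one_ulp_up) - float_to_ordered_int(t1))  # 1
```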
Isn't @varchasgopalaswamy's suggestion more akin to using a relative rather than an absolute tolerance, so that it scales with t1? Rearranging and renaming:
clip = tnext > (1 - rtol) * t1
If you really want to deal with extreme cases of underflow we could still keep a smaller atol:
clip = tnext > (1 - rtol) * t1 - atol
Another possibility, which might be more appropriate for non-zero t0 (but more likely too aggressive), is
clip = tnext > t1 - rtol * (t1 - t0) - atol
Finally, we could also clip (or partition evenly?) if the remaining time step is a small fraction (maybe 1%?) of the previous, e.g.:
clip = tnext > t1 - max(rtol * (t1 - t0) - atol, min_dt_frac * (tnext - tprev))
Maybe sensible definitions could be:
rtol = 100 * mach_eps # = 1.1e-14 for double and 6e-6 for single
atol = min_normal / mach_eps # = 2e-292 for double and 1.8e-31 for single
This should be a pretty good proxy for 100ULP which is approximately proportional to x. Would there be any issues with moving in this direction?
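A quick sanity check on those values in double precision (pure Python; should_clip is a hypothetical name for the fourth option above, without the tprev term):

```python
import sys

# Taking mach_eps as the unit roundoff (half of sys.float_info.epsilon)
# reproduces the quoted numbers.
mach_eps = sys.float_info.epsilon / 2   # ~1.1e-16
min_normal = sys.float_info.min         # ~2.2e-308
rtol = 100 * mach_eps                   # ~1.1e-14
atol = min_normal / mach_eps            # ~2e-292

def should_clip(tnext, t0, t1):
    # Option four above, minus the tprev term.
    return tnext > t1 - rtol * (t1 - t0) - atol

tau = 1e-11
print(should_clip(tau - tau / 1000, 0.0, tau))  # False: a full dt0 remains
print(should_clip(tau - 1e-26, 0.0, tau))       # True: within tolerance of t1
```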
It would also be nice if we could add error checking in ConstantStepSize to detect when dt0 < rtol * (t1 - t0) + atol, so that nothing fails silently, but this might mean requiring that our times be primitive floats rather than allowing arrays.
Is the problem with the linspace approach that we'd have to commit it to memory, which may not be ideal for a simulation with O(1e6) timesteps? An effectively equivalent way of achieving this behaviour would be to add support for when dt0 is not provided: take num_steps = max_steps steps of size (t1 - t0) / max_steps, incrementing by dt each time until the penultimate step, and then clip to t1 at the last step (which we know in advance is step max_steps).
Tagging @aidancrilly for visibility.
Would there be any issues with moving in this direction?
I think my main concern is just that this introduces this dependence on the previous step, and I could see this extra dependence introducing hard-to-{debug, understand} behavior.
Is there a case when you think this would be preferred over just a fixed tolerance of e.g. 100ULPs?
Is the problem with the linspace approach that we'd have to commit it to memory, which may not be ideal for a simulation with O(1e6) timesteps?
Yup.
I imagine we could probably directly find O(1) equivalents if we needed, though. Something like jnp.where(step == num_steps, t1, t0 + (t1 - t0) * (step / num_steps)) rather than incrementing by dt0.
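A plain-Python sketch of that O(1) scheme (step_time is a hypothetical helper mirroring the jnp.where expression, not diffrax's actual code): the final step is forced to land exactly on t1, and intermediate times are computed directly from the step index rather than accumulated.

```python
def step_time(i, t0, t1, num_steps):
    # Force the last step to land exactly on t1; compute the others
    # directly from the step index instead of accumulating dt0.
    if i == num_steps:
        return t1
    return t0 + (t1 - t0) * (i / num_steps)

t0, t1, n = 0.0, 1e-11, 1000
print(step_time(n, t0, t1, n) == t1)  # True, by construction
```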
Only the fourth option included a dependency on tprev, so we could leave that out. Both the ULP and the rtol approaches should scale with t1, meaning that users could manage a similar number of timesteps without normalising their times explicitly. The reason I suggested this is that you framed the ULP approach as hacky, and I couldn't work out any other JAX-friendly approach to achieve an actual ULP.
Yes, I like the where approach for the dt=None option!
The reason I suggested this is that you framed the ULP approach as hacky
Ah, for this I was specifically framing the float->int conversion as hacky, since it relies on details of the bit representation of the floating-point number.
and I couldn't work out any other JAX-friendly approach to achieve an actual ULP.
So there is jnp.nextafter for moving one ULP in either direction. We have these wrapped into eqx.internal.{prevbefore,nextafter} which additionally offer some safety around denormal numbers (which are frequently not available on accelerators).
Not sure I understand the reason to convert to an integer. Could we just do this then:
tol = 100.0 * (eqx.internal.nextafter(t1) - eqx.internal.prevbefore(t1))
And keep your current approach?
If so, happy to chip in; but otherwise I think I might be missing some important subtleties here, and will stick with normalising on our side for now.
Nevertheless, the where approach might work for user-specified t1 (even for StepTo, if we recompute num_steps for each "stepto"), adding num_steps = (t1 - t0) // dt (though at worst this could make the last step almost twice as big as the others, so it's not ideal).
tol = 100.0 * (eqx.internal.nextafter(t1) - eqx.internal.prevbefore(t1))
So this is probably mostly fine, just a little iffy as compared to e.g.
clip_t = t1
for _ in range(100):
    clip_t = prevbefore(clip_t)
due to floating-point rounding.
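In pure Python the same loop can be written with math.nextafter, the stdlib analogue of the eqx.internal helpers mentioned above (a sketch, not the diffrax implementation):

```python
import math

t1 = 1e-11
# Walk 100 representable floats below t1, one ULP at a time.
clip_t = t1
for _ in range(100):
    clip_t = math.nextafter(clip_t, -math.inf)
print(clip_t < t1)  # True
# Away from exponent boundaries this agrees with subtracting 100 ULPs:
print(clip_t == t1 - 100 * math.ulp(t1))
```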
Not sure I understand the reason to convert to an integer.
Sorry, this was a comment I didn't make super clearly. I was jumping straight to the above code snippet, and reasoning that the efficient way to represent this would often be (pseudocode) bitcast_to_float(bitcast_to_int(t1) - 100).
On reflection I've realised that that's not really true though: this only works if you don't underflow the mantissa. Just iterating prevbefore is probably fine.
Anyway, it sounds like we're probably converging on two possible changes here:
- The '100 ULPs of tolerance' approach, as an alternative to the current clipping tolerances.
- Adjusting ConstantStepSize not to iteratively add dt0.
I'd be happy to consider a PR on each of these.