Different results for Dopri5 depending on a loop vs. a single solve
I am trying to solve the Lorenz butterfly using diffrax, and in particular the Dopri5 solver.
I take two approaches.
In one of them, I simply solve over the entire time range (0, 40), using a step size DISCRETE_DT=1e-2.
In the other, I solve over each interval (t, t+DISCRETE_DT) and then call this solver in a loop.
Annoyingly, the discrepancy between these two approaches grows with the sequence length.
Does anyone have a good explanation for this phenomenon? I know the Lorenz butterfly is chaotic and that Dopri5 takes adaptive step sizes, but:
- I'm using float64
- I've checked, and no steps are rejected in either trajectory. So I would have expected exactly the same numerical computations in both settings.
Here is a colab with a repro: https://colab.research.google.com/drive/1aNjUCB4SmQ_UDJ3H5Lqu4jzfFxOwY6Q6?usp=sharing
Hi Xavier,
Adaptive step sizes in diffrax are not a property of the solver. They are controlled by step size controllers, which are documented here. You're not specifying one, so diffeqsolve falls back to its default, a constant step size (ConstantStepSize), which is why you do not see any rejected steps.
Now, for the numerical discrepancy: I suspect this is because you're using a regular Python loop, which hands data and control back and forth between Python and JAX. You probably accumulate a little bit of error there.
In your notebook, you also mention that the SaveAt buffer is as long as the number of steps - this is by design, and required for static shape compilation with adaptive step sizing. If you want to save the solution at specific steps only, you can use SaveAt(ts=...).
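Interestingly, if you change the t's to match along the trajectory, i.e. swap the first step_lorenz below for the second one, which passes the running time t into diffeqsolve:

```python
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve

# TERM, SOLVER, DISCRETE_DT, init_state and T are defined as in the linked colab.


# Original version: every step restarts at t0=0 (shows numerical error).
@jax.jit
def step_lorenz(x_curr, _):
    solution = diffeqsolve(
        TERM, SOLVER, t0=0.0, t1=DISCRETE_DT, dt0=DISCRETE_DT, y0=x_curr
    )
    num_rejected_steps = solution.stats["num_rejected_steps"]
    return solution.ys.squeeze(), num_rejected_steps


# Modified version: t0/t1 follow the running time t (no numerical error).
# This definition replaces the one above.
@jax.jit
def step_lorenz(x_curr, t):
    solution = diffeqsolve(
        TERM, SOLVER, t0=t, t1=t + DISCRETE_DT, dt0=DISCRETE_DT, y0=x_curr
    )
    num_rejected_steps = solution.stats["num_rejected_steps"]
    return solution.ys.squeeze(), num_rejected_steps


def rollout_lorenz(x0, num_steps):
    trajectory = []
    x = x0
    num_rejected_steps = 0
    t = 0  # running time, advanced in Python after every step
    for _ in range(num_steps):
        x, rejected_steps = step_lorenz(x, t)
        t += DISCRETE_DT
        num_rejected_steps += rejected_steps
        trajectory.append(x)
    return jnp.array(trajectory), num_rejected_steps


trajectory, num_rejected_steps = rollout_lorenz(init_state, T)
print(f"number of rejected steps was {num_rejected_steps}")
```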
I see no numerical error
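One possible contributor, not confirmed in this thread: if the step width used by the integrator is computed as t1 - t0 from times accumulated in floating point, that width is not exactly DISCRETE_DT once t is away from 0, whereas restarting every interval at t0=0 always uses exactly DISCRETE_DT. In a chaotic system, last-bit differences like that would then be amplified over the trajectory. A pure-Python check of just the arithmetic, assuming 4000 steps to cover (0, 40):

```python
DISCRETE_DT = 1e-2
NUM_STEPS = 4000  # (0, 40) with DISCRETE_DT = 1e-2

t = 0.0
widths = []
for _ in range(NUM_STEPS):
    t_next = t + DISCRETE_DT  # time accumulated in float64, like `t += DISCRETE_DT`
    widths.append(t_next - t)  # width of this step as seen from its endpoints
    t = t_next

num_inexact = sum(w != DISCRETE_DT for w in widths)
max_dev = max(abs(w - DISCRETE_DT) for w in widths)
print(f"{num_inexact} of {NUM_STEPS} widths differ from DISCRETE_DT; max deviation {max_dev:.3e}")
```

This only shows the floating-point effect; whether diffeqsolve's internals actually follow this pattern is an assumption that the manual-stepping experiment suggested next could check.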
That's really interesting, @lockwo. I wonder why this is!
Yeah, I'm not totally sure. I'd start by trying a plain solver.step approach: if that also differs, the cause is in the solver; otherwise, diffeqsolve might be doing something with the time. (At least, that would be my approach to narrowing down the discrepancy.)