diffrax
Possible performance issue with implicit solvers
Hello @patrick-kidger, I was investigating the implicit solvers of diffrax for a possible integration into dynamiqs. However, I am seeing that the performance of `Kvaerno3` and `Kvaerno5` is highly dependent on the choice of `atol` and `rtol`. Is this expected? Or could this be a bug somewhere?
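For context on what `rtol` and `atol` control: adaptive step-size controllers typically accept a step when the local error estimate, measured against a per-component scale `atol + rtol * |y|`, has a (root-mean-square) norm of at most 1, and otherwise shrink the step and retry. Below is a minimal sketch of that textbook criterion plus a classic integral ("I") step-size update. The helper names `scaled_error_ratio` and `step_factor`, the `safety`/`factormin`/`factormax` defaults and the `error_order` value are assumptions for illustration; this is not diffrax's actual implementation of `dx.PIDController`.

```python
import jax.numpy as jnp


def scaled_error_ratio(y0, y1, y_error, rtol, atol):
    # Hypothetical helper: RMS norm of the local error estimate, measured in units
    # of the per-component tolerance atol + rtol * max(|y0|, |y1|).
    scale = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
    return jnp.sqrt(jnp.mean((y_error / scale) ** 2))


def step_factor(ratio, error_order, safety=0.9, factormin=0.2, factormax=10.0):
    # Hypothetical helper: classic integral ("I") controller. A step is accepted
    # when ratio <= 1; the next step size is the current one times this factor.
    factor = safety * ratio ** (-1.0 / (error_order + 1))
    return jnp.clip(factor, factormin, factormax)


# Same local error estimate, two tolerances (rtol = atol = tol, as the MWE does).
y = jnp.ones(4)
err = jnp.full(4, 1e-5)
for tol in (1e-4, 1e-6):
    r = scaled_error_ratio(y, y, err, rtol=tol, atol=tol)
    print(
        f"tol={tol:g}: ratio={float(r):.3g}, accepted={bool(r <= 1.0)}, "
        f"dt factor={float(step_factor(r, error_order=4)):.3g}"
    )
```

With the same local error estimate of 1e-5, tightening the tolerance from 1e-4 to 1e-6 turns an accepted step (ratio 0.05, step grows) into a rejected one (ratio 5, step shrinks), so the step count of an adaptive solver is expected to depend on the tolerance to some degree.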
Here is a minimal working example (MWE) for a typical stiff ODE we want to solve:
```python
import diffrax as dx
import jax.numpy as jnp

# simulation parameters
N = 16
alpha = 2.0
T = 1.0

# quantum operators (a: truncated annihilation operator, L: jump operator)
a = jnp.diag(jnp.sqrt(jnp.arange(1, N)), 1)
i = jnp.eye(N)
L = jnp.linalg.matrix_power(a, 4) - alpha**4 * i
Lt = L.T
LtL = Lt @ L

# vector field: Lindblad dissipator L y L^T - 1/2 (L^T L y + y L^T L)
def vector_field(t, y, _):
    return L @ y @ Lt - 0.5 * (LtL @ y + y @ LtL)

# initial state
y0 = jnp.zeros((N, N))
y0 = y0.at[0, 0].set(1.0)

# define solver
tsave = jnp.linspace(0.0, T, 100)
solve = lambda solver, tol: dx.diffeqsolve(
    dx.ODETerm(vector_field),
    solver,
    t0=0.0,
    t1=T,
    dt0=0.01,
    y0=y0,
    saveat=dx.SaveAt(ts=tsave),
    stepsize_controller=dx.PIDController(rtol=tol, atol=tol),
    max_steps=100_000,
    progress_meter=dx.TqdmProgressMeter(),
)

# test several explicit and implicit solvers with different tolerances
solver = dx.Tsit5()
num_steps = solve(solver, tol=1e-4).stats["num_steps"]  # 4037
num_steps = solve(solver, tol=1e-5).stats["num_steps"]  # 4039
num_steps = solve(solver, tol=1e-6).stats["num_steps"]  # 4039

solver = dx.Kvaerno5()
num_steps = solve(solver, tol=1e-4).stats["num_steps"]  # 12
num_steps = solve(solver, tol=1e-5).stats["num_steps"]  # 3788
num_steps = solve(solver, tol=1e-6).stats["num_steps"]  # 11902

solver = dx.Kvaerno3()
num_steps = solve(solver, tol=1e-4).stats["num_steps"]  # 25
num_steps = solve(solver, tol=1e-5).stats["num_steps"]  # 561
num_steps = solve(solver, tol=1e-6).stats["num_steps"]  # 27031
```
As you can see, for `Tsit5` the number of steps (and thus the solver cost) stays essentially constant, at around 4040, as the tolerance is tightened. However, for `Kvaerno5` and `Kvaerno3`, tightening the tolerance by two orders of magnitude (from 1e-4 to 1e-6) makes the number of steps explode: from 12 to 11902 for `Kvaerno5`, and from 25 to 27031 for `Kvaerno3`.
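One way to quantify how stiff this problem is: the vector field is linear in `y`, so its Jacobian with respect to the flattened state is the full 256 x 256 generator, and its eigenvalues set the time scales of the dynamics. The sketch below (pure JAX, no diffrax; the helper `flat_field` is hypothetical) computes that spectrum and, for one step of size `dt0 = 0.01`, compares the amplification factor |R(h*lambda)| of forward Euler and backward Euler on the scalar test equation y' = lambda*y. The two Euler methods are used only as the simplest explicit/implicit pair; this is an illustration, not an analysis of `Tsit5` or the Kvaerno methods. Note that `jnp.linalg.eigvals` on a non-symmetric matrix is only supported on the CPU backend.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 for a more reliable eigensolve

# Same operators and vector field as in the MWE above.
N = 16
alpha = 2.0
a = jnp.diag(jnp.sqrt(jnp.arange(1, N)), 1)
i = jnp.eye(N)
L = jnp.linalg.matrix_power(a, 4) - alpha**4 * i
Lt = L.T
LtL = Lt @ L


def vector_field(t, y, _):
    return L @ y @ Lt - 0.5 * (LtL @ y + y @ LtL)


def flat_field(v):
    # Hypothetical helper: the vector field acting on a flattened (N*N,) state.
    return vector_field(0.0, v.reshape(N, N), None).ravel()


# The field is linear, so its Jacobian (at any point) is the full generator matrix.
J = jax.jacfwd(flat_field)(jnp.zeros(N * N))
eigs = jnp.linalg.eigvals(J)  # non-symmetric eigensolver: CPU backend only

mean_eig = jnp.trace(J) / (N * N)  # average eigenvalue (sum of eigenvalues = trace)
trace_leak = jnp.abs(jnp.eye(N).ravel() @ J).max()  # ~0: the field preserves Tr(y)

# Amplification factors |R(h*lambda)| for one step of size h = dt0 on y' = lambda*y.
h = 0.01
forward_euler = jnp.abs(1 + h * eigs)  # explicit: R(z) = 1 + z
backward_euler = jnp.abs(1 / (1 - h * eigs))  # implicit: R(z) = 1 / (1 - z)

print("mean eigenvalue:", float(mean_eig))
print("most negative real part:", float(eigs.real.min()))
print("trace leak:", float(trace_leak))
print("max |R|, forward Euler:", float(forward_euler.max()))
# Should be <= 1 up to round-off in the eigensolve.
print("max |R|, backward Euler:", float(backward_euler.max()))
```

For this operator the trace can also be worked out by hand: each diagonal entry of the Jacobian is -(p_i + p_j)/2 with p_i = i(i-1)(i-2)(i-3), so the trace is -16 * 104832 and the mean eigenvalue is -6552. The most negative eigenvalue is therefore at least that large in magnitude, which gives |1 + h*lambda| > 64 at h = 0.01 (forward Euler unstable at that step size), whereas |1/(1 - h*lambda)| <= 1 for every eigenvalue with Re(lambda) <= 0. This would be consistent with the `Tsit5` step count being set by stability rather than by the tolerance.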