
Possible performance issue with implicit solvers

Open gautierronan opened this issue 8 months ago • 2 comments

Hello @patrick-kidger, I was investigating the implicit solvers of diffrax for a possible integration into dynamiqs. However, I am seeing that the performance of Kvaerno3 and Kvaerno5 is highly dependent on the choice of atol and rtol. Is this expected? Or could this be a bug somewhere?

Here is a MWE, for a typical stiff ODE we want to solve:

import diffrax as dx
import jax.numpy as jnp

# simulation parameters
N = 16
alpha = 2.0
T = 1.0

# quantum operators
a = jnp.diag(jnp.sqrt(jnp.arange(1, N)), 1)
i = jnp.eye(N)
L = jnp.linalg.matrix_power(a, 4) - alpha**4 * i
Lt = L.T
LtL = Lt @ L

# vector field
def vector_field(t, y, _):
    return L @ y @ Lt - 0.5 * (LtL @ y + y @ LtL)

# initial state
y0 = jnp.zeros((N, N))
y0 = y0.at[0, 0].set(1.0)

# define solver
tsave = jnp.linspace(0.0, T, 100)
solve = lambda solver, tol: dx.diffeqsolve(
    dx.ODETerm(vector_field),
    solver,
    t0=0.0,
    t1=T,
    dt0=0.01,
    y0=y0,
    saveat=dx.SaveAt(ts=tsave),
    stepsize_controller=dx.PIDController(rtol=tol, atol=tol),
    max_steps=100_000,
    progress_meter=dx.TqdmProgressMeter(),
)

# test several explicit and implicit solvers with different tolerances
solver = dx.Tsit5()
num_steps = solve(solver, tol=1e-4).stats["num_steps"] # 4037
num_steps = solve(solver, tol=1e-5).stats["num_steps"] # 4039
num_steps = solve(solver, tol=1e-6).stats["num_steps"] # 4039

solver = dx.Kvaerno5()
num_steps = solve(solver, tol=1e-4).stats["num_steps"] # 12
num_steps = solve(solver, tol=1e-5).stats["num_steps"] # 3788
num_steps = solve(solver, tol=1e-6).stats["num_steps"] # 11902

solver = dx.Kvaerno3()
num_steps = solve(solver, tol=1e-4).stats["num_steps"] # 25
num_steps = solve(solver, tol=1e-5).stats["num_steps"] # 561
num_steps = solve(solver, tol=1e-6).stats["num_steps"] # 27031

As you can see, for Tsit5 the number of steps (and thus the solver performance) is very stable as the tolerances are tightened. For Kvaerno5 and Kvaerno3, however, tightening the tolerances by two orders of magnitude (from 1e-4 to 1e-6) makes the number of steps blow up: from 12 to 11902 for Kvaerno5, and from 25 to 27031 for Kvaerno3.
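As a rough check of the claim that this ODE is stiff, the sketch below rewrites the vector field as a single matrix acting on the flattened state and looks at its spectrum. This is only an illustration and not part of the MWE: it uses NumPy in float64 instead of jax.numpy, and the names `superop`, `y_test` and `spectral_radius` are made up for this example.

```python
import numpy as np

# same parameters and operators as in the MWE (NumPy, float64: an assumption)
N = 16
alpha = 2.0

a = np.diag(np.sqrt(np.arange(1, N)), 1)
eye = np.eye(N)
L = np.linalg.matrix_power(a, 4) - alpha**4 * eye
Lt = L.T
LtL = Lt @ L

def vector_field(y):
    return L @ y @ Lt - 0.5 * (LtL @ y + y @ LtL)

# With row-major flattening, vec(A @ Y @ B) = kron(A, B.T) @ vec(Y).
# Here B = Lt for the first term (so B.T = L), and LtL is symmetric.
superop = np.kron(L, L) - 0.5 * (np.kron(LtL, eye) + np.kron(eye, LtL))

# sanity check: the matrix form agrees with the original vector field
y_test = np.arange(N * N, dtype=float).reshape(N, N) / N**2
assert np.allclose(superop @ y_test.reshape(-1), vector_field(y_test).reshape(-1))

# largest eigenvalue magnitude of the linear ODE dy/dt = superop @ y
eigvals = np.linalg.eigvals(superop)
spectral_radius = np.max(np.abs(eigvals))
print("spectral radius:", spectral_radius)
print("spectral radius * T (T = 1.0):", spectral_radius * 1.0)
```

If the spectral radius is large compared to 1/T, an explicit solver is likely limited by stability rather than accuracy, which would be consistent with the Tsit5 step count barely changing with the tolerance. It does not by itself explain the behaviour of the Kvaerno solvers.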

gautierronan, May 30 '24 15:05