Question about an “AllSaveAdjoint” (store-all-steps, no recomputation) variant
Thanks for the amazing project. I find it very useful and blazing fast for quantum simulation, but I don't know the implementation details of the solvers. Here is my attempt to understand how these solvers work.
Currently Diffrax exposes several adjoint/backpropagation strategies: RecursiveCheckpointAdjoint (default, binomial checkpointing), DirectAdjoint (forward+reverse capable but less efficient), and BacksolveAdjoint (continuous adjoint; not recommended due to approximate gradients). I did not find an option that simply stores every intermediate solver state once and performs a single reverse sweep with zero recomputation (time O(n), memory O(n)); the closest seems to be increasing the number of checkpoints in RecursiveCheckpointAdjoint.
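For reference, this is how I currently select between these strategies via the adjoint argument of diffeqsolve (just an illustrative sketch, not the full call I use):

import diffrax as dfx

# The three existing strategies, chosen via the `adjoint` argument of diffeqsolve:
adjoint = dfx.RecursiveCheckpointAdjoint()  # default; binomial checkpointing
# adjoint = dfx.RecursiveCheckpointAdjoint(checkpoints=100_000)  # "more checkpoints" workaround
# adjoint = dfx.DirectAdjoint()      # forward- and reverse-mode capable, less efficient
# adjoint = dfx.BacksolveAdjoint()   # continuous adjoint; approximate gradients

# sol = dfx.diffeqsolve(term, solver, t0, t1, dt0, y0, adjoint=adjoint)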
Questions / motivation:
- Is there a deliberate design reason for not exposing a convenience variant that forces saving every step (e.g. AllSaveAdjoint), given that users with ample memory might prefer pure O(n) time without the log n factor?
- Can setting checkpoints to a very large number (e.g. equal to or exceeding the actual step count) already emulate “all-save”, or are there structural evictions / internal policies that still cause recomputation even when memory allows?
Use cases:
- Small-to-moderate problems where n is not huge (such as quantum simulation of small quantum systems, which is basically all we can do on a classical computer, so this may cover the full use case), GPU memory is plentiful, and wall-clock latency per training step matters more than conserving RAM.
- Benchmarking: having a no-recompute baseline helps quantify the overhead of checkpoint heuristics or of DirectAdjoint.
- Situations combining reverse-mode gradients with experimentation on solver internals, where deterministic “no replay” profiling is simpler.
I'm new to differential equation solvers, so if there are any misunderstandings please let me know; I'd appreciate it :)
I did not find an option that simply stores every intermediate solver state once and performs a single reverse sweep with zero recomputation (time O(n), memory O(n)); the closest seems to be increasing the number of checkpoints in RecursiveCheckpointAdjoint.
This is just RecursiveCheckpointAdjoint with checkpoints=max_steps, right? I guess this is just a generalization of strict reverse mode / no checkpointing (one could make a wrapper that does the above). You can also verify that every step is being checkpointed by using the JAX tooling discussed in https://github.com/patrick-kidger/diffrax/issues/564
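For instance, a small wrapper along those lines might look like this (the name all_save_adjoint is just illustrative, not part of Diffrax):

import diffrax as dfx

def all_save_adjoint(max_steps: int) -> dfx.RecursiveCheckpointAdjoint:
    # With checkpoints >= the number of steps actually taken, every step end is
    # stored, so the backward pass never replays whole steps from a sparser set
    # of checkpoints (the internals of each step are still recomputed).
    return dfx.RecursiveCheckpointAdjoint(checkpoints=max_steps)

# usage: diffeqsolve(..., max_steps=100_000, adjoint=all_save_adjoint(100_000))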
I have tested RecursiveCheckpointAdjoint with checkpoints=max_steps using the following simple code:
import jax
import jax.numpy as jnp
import diffrax as dfx
from jax.ad_checkpoint import print_saved_residuals

def vec_field(t, y, args):
    return -0.1 * y + 0.1 * jnp.sin(0.2 * t)

term = dfx.ODETerm(vec_field)
t0 = 0
T = 8000
dt0 = 0.1
y_size = 2250
y0 = jnp.ones((y_size,))
solver = dfx.Tsit5()

def f(y0):
    return dfx.diffeqsolve(
        term,
        solver,
        t0,
        T,
        dt0=dt0,
        y0=y0,
        max_steps=100000,
        # switch between checkpoints=None and checkpoints=100000 for the comparison below
        adjoint=dfx.RecursiveCheckpointAdjoint(checkpoints=None),
    )

grad_solve = jax.grad(lambda x: jnp.sum(f(x).ys[-1]))
a = grad_solve(y0)
jax.block_until_ready(a)
print_saved_residuals(f, y0)
If I set checkpoints=100000, which equals max_steps, the final time cost is about 50 seconds and the memory cost is over 2 GB.
If I set checkpoints=None, which means binomial checkpointing is used, the final time cost is about 70 seconds and the memory cost is about 200 MB.
With checkpoints=100000, two f32[100000,2250] residuals are saved during the forward pass.
With checkpoints=None, two f32[446,2250] residuals are saved during the forward pass.
This meets expectations.
It does seem that
setting checkpoints to a very large number (e.g. equal to or exceeding actual step count) already emulate “all-save”
But is it exactly “all-save”? Or is it not necessary to consider the difference?
By the way, there is some memory overhead when setting checkpoints=max_steps, because the actual number of steps may be less than max_steps.
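One way I could reduce that overhead (just a sketch; it assumes the forward solve can be run once beforehand, and that sol.stats["num_steps"] reports the number of steps taken) would be to size the checkpoints to the measured step count instead of max_steps:

# Measure the real number of steps with a plain (non-differentiated) solve first.
sol = dfx.diffeqsolve(term, solver, t0, T, dt0=dt0, y0=y0, max_steps=100_000)
num_steps = int(sol.stats["num_steps"])

def f_sized(y0):
    return dfx.diffeqsolve(
        term, solver, t0, T, dt0=dt0, y0=y0,
        max_steps=100_000,
        adjoint=dfx.RecursiveCheckpointAdjoint(checkpoints=num_steps),
    )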
I made a small demo without using diffeqsolve
from functools import partial

# Reuses term, solver, t0, T, dt0 and y0 from the snippet above.
@partial(jax.jit, static_argnums=(1,))
def f(y0, max_steps: int = 100_000):
    # initialize carry
    state0 = solver.init(term, t0, t0 + dt0, y0, None)
    carry0 = (y0, state0, t0, False)

    @jax.checkpoint
    def step_fn(carry):
        y, state, t_prev = carry
        t_next = jnp.minimum(t_prev + dt0, T)
        y_new, _, _, state_new, _ = solver.step(
            term, t_prev, t_next, y, None, state, made_jump=False
        )
        return (y_new, state_new, t_next)

    @jax.jit
    def body(carry, _):
        y, state, t_prev, done = carry
        (y_new, state_new, t_next) = jax.lax.cond(
            ~done,
            step_fn,
            lambda *_: (y, state, t_prev),
            (y, state, t_prev),
        )
        done_new = done | (t_next >= T)
        return (y_new, state_new, t_next, done_new), None

    (y_final, _, _, _), _ = jax.lax.scan(body, carry0, xs=None, length=max_steps)
    return y_final
grad_solve = jax.grad(lambda x: jnp.sum(f(x)))
a = grad_solve(y0)
jax.block_until_ready(a)
print_saved_residuals(f, y0)
The time cost is similar to the checkpoints=100000 case: about 50 seconds.
Hi @littlebaker! First of all, your description of how things work at the moment is absolutely spot-on. Indeed RecursiveCheckpointAdjoint only records the values at the end of each numerical step, and recomputes the internal operations during its backward pass (in addition to any binomial checkpointing structure over the loop as a whole).
If you set the number of checkpoints to a large number (>=max_steps) then you should find that the end of every step is checkpointed and there is no binomial structure, but the internals of each step are still recomputed as above.
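To make that distinction concrete, here is a toy sketch in plain JAX (nothing Diffrax-specific): wrapping a step in jax.checkpoint keeps only the per-step carry as a residual and replays the step's internals on the backward pass, whereas an unwrapped step also stores its internal intermediates.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def step(y):
    # Two "internal" operations; without checkpointing, intermediates from these
    # are saved as residuals for the backward pass.
    a = jnp.sin(y)
    b = jnp.cos(a)
    return y + 0.1 * b

def loop(y0, recompute_internals: bool, n: int = 1000):
    body_step = jax.checkpoint(step) if recompute_internals else step
    def body(y, _):
        # The per-step carry is saved by scan's reverse pass in both cases.
        return body_step(y), None
    y, _ = jax.lax.scan(body, y0, None, length=n)
    return jnp.sum(y)

y0 = jnp.ones((2250,))
print_saved_residuals(lambda y: loop(y, recompute_internals=True), y0)   # fewer residuals per step
print_saved_residuals(lambda y: loop(y, recompute_internals=False), y0)  # internals stored as well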
Now for reference, the underlying binomial checkpointing algorithm is implemented here:
https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/checkpointed.py
And the reason we don't also implement the save-everything approach is simply that this code is quite complicated enough already! I had to pick one of the two approaches when implementing this, and the recomputation one was easier, and arguably a safer choice as a default. (As a side note, to my knowledge this is the only implementation of variable-length recursive checkpointing anywhere in modern autodiff libraries.)