diffrax

AbstractReversibleSolver + ReversibleAdjoint

Open sammccallum opened this issue 10 months ago • 4 comments

Re-opening #593.

Implements an AbstractReversibleSolver base class and a ReversibleAdjoint for reversible backpropagation.

This updates SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun to subclass AbstractReversibleSolver.
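For intuition on why a solver like SemiImplicitEuler fits this interface, here is a dependency-free sketch of a semi-implicit Euler update of the form "update y1 using the old y2, then update y2 using the new y1", together with its inverse. The toy pendulum fields `f` and `g` and the helper names are assumptions for illustration; diffrax's SemiImplicitEuler operates on AbstractTerms rather than plain functions. In this sketch there is no extra solver state, so undoing the two updates in reverse order is all the backward step has to do.

```python
import math


def f(t, y2):
    # Toy field driving the first component (an assumption): dy1/dt = y2.
    return y2


def g(t, y1):
    # Toy field driving the second component (an assumption): dy2/dt = -sin(y1).
    return -math.sin(y1)


def sie_step(t0, t1, y):
    # Semi-implicit Euler: update y1 with the old y2, then y2 with the *new* y1.
    y1, y2 = y
    dt = t1 - t0
    y1_new = y1 + f(t0, y2) * dt
    y2_new = y2 + g(t0, y1_new) * dt
    return (y1_new, y2_new)


def sie_backward_step(t0, t1, y):
    # Undo the two updates in reverse order; each only needs values already known.
    y1_new, y2_new = y
    dt = t1 - t0
    y2 = y2_new - g(t0, y1_new) * dt
    y1 = y1_new - f(t0, y2) * dt
    return (y1, y2)


ts = [0.01 * i for i in range(101)]
y = (1.0, 0.0)
for t0, t1 in zip(ts[:-1], ts[1:]):
    y = sie_step(t0, t1, y)
for t0, t1 in reversed(list(zip(ts[:-1], ts[1:]))):
    y = sie_backward_step(t0, t1, y)
print(y)  # back to (1.0, 0.0) up to floating-point rounding
```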

Implementation

AbstractReversibleSolver subclasses AbstractSolver and adds a backward_step method:

@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, DenseInfo, _SolverState]:
    """Return (y0, dense_info, solver_state at t0), reconstructed from y1 and solver_state at t1."""

This method should reconstruct y0 and the solver_state at t0 from y1 and the solver_state at t1. See the aforementioned solvers for examples.
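A dependency-free sketch of what such a backward_step can look like, using a toy reversible Heun scheme whose solver_state is (yhat, vf) with vf = f(t, yhat). `ToyReversibleHeun` is hypothetical: `terms` is a plain callable rather than an AbstractTerm, made_jump is accepted but ignored, and `step` returns a simplified (y1, dense_info, solver_state) triple instead of diffrax's full AbstractSolver.step output. Running N steps forward and then N backward steps should recover the initial state up to floating-point rounding.

```python
import math


class ToyReversibleHeun:
    """Hypothetical stand-in; solver_state is (yhat, vf) with vf = terms(t, yhat, args)."""

    def init(self, terms, t0, t1, y0, args):
        # Start with yhat0 = y0 and cache the vector field there.
        return (y0, terms(t0, y0, args))

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        # made_jump is ignored in this sketch.
        yhat0, vf0 = solver_state
        dt = t1 - t0
        yhat1 = 2 * y0 - yhat0 + vf0 * dt
        vf1 = terms(t1, yhat1, args)
        y1 = y0 + 0.5 * (vf0 + vf1) * dt
        return y1, dict(y0=y0, y1=y1), (yhat1, vf1)

    def backward_step(self, terms, t0, t1, y1, args, solver_state, made_jump):
        # Invert `step` algebraically: recover yhat0, then vf0, then y0.
        yhat1, vf1 = solver_state
        dt = t1 - t0
        yhat0 = 2 * y1 - yhat1 - vf1 * dt
        vf0 = terms(t0, yhat0, args)
        y0 = y1 - 0.5 * (vf0 + vf1) * dt
        return y0, dict(y0=y0, y1=y1), (yhat0, vf0)


vf = lambda t, y, args: -args * y + math.sin(t)  # toy vector field
solver = ToyReversibleHeun()
ts = [0.1 * i for i in range(11)]
y, state = 1.0, solver.init(vf, ts[0], ts[1], 1.0, 0.5)
for t0, t1 in zip(ts[:-1], ts[1:]):
    y, _, state = solver.step(vf, t0, t1, y, 0.5, state, False)
for t0, t1 in reversed(list(zip(ts[:-1], ts[1:]))):
    y, _, state = solver.backward_step(vf, t0, t1, y, 0.5, state, False)
```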

When backpropagating, ReversibleAdjoint uses backward_step to reconstruct the state at the start of each step. We then take a VJP through a local forward step from that reconstructed state and accumulate the gradients.
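A minimal JAX sketch of that backward pass, assuming a reversible Heun step on the toy scalar ODE dy/dt = -θy and the loss y(T)². `step`, `backward_step` and `reversible_grad` are hypothetical plain functions, not diffrax's ReversibleAdjoint; the point is the loop structure (reconstruct, local VJP, accumulate), which is checked against ordinary backpropagation through a stored forward pass.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(t, y, theta):
    # Toy vector field (an assumption for illustration): dy/dt = -theta * y.
    return -theta * y


def step(t0, t1, state, theta):
    # Reversible Heun step; the carried state is (y, yhat, vf).
    y0, yhat0, vf0 = state
    dt = t1 - t0
    yhat1 = 2 * y0 - yhat0 + vf0 * dt
    vf1 = f(t1, yhat1, theta)
    y1 = y0 + 0.5 * (vf0 + vf1) * dt
    return (y1, yhat1, vf1)


def backward_step(t0, t1, state, theta):
    # Algebraic inverse of `step`: recovers the state at t0 from the state at t1.
    y1, yhat1, vf1 = state
    dt = t1 - t0
    yhat0 = 2 * y1 - yhat1 - vf1 * dt
    vf0 = f(t0, yhat0, theta)
    y0 = y1 - 0.5 * (vf0 + vf1) * dt
    return (y0, yhat0, vf0)


def reversible_grad(ts, state0, theta):
    """Gradient of y(ts[-1])**2 w.r.t. (state0, theta) without storing the trajectory."""
    state = state0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        state = step(t0, t1, state, theta)
    # Cotangent of the loss w.r.t. the final state (only y enters the loss).
    ct_state = (2 * state[0], jnp.zeros_like(state[1]), jnp.zeros_like(state[2]))
    ct_theta = jnp.zeros_like(theta)
    for t0, t1 in reversed(list(zip(ts[:-1], ts[1:]))):
        # 1. Reconstruct the state at t0 instead of reading it from memory.
        state = backward_step(t0, t1, state, theta)
        # 2. VJP through a local forward step from the reconstructed state.
        _, vjp_fn = jax.vjp(lambda s, th: step(t0, t1, s, th), state, theta)
        ct_state, ct_th = vjp_fn(ct_state)
        # 3. Accumulate the parameter gradient.
        ct_theta = ct_theta + ct_th
    return ct_state, ct_theta


ts = [0.05 * i for i in range(21)]
theta = jnp.array(0.7)
y0 = jnp.array(1.0)
state0 = (y0, y0, f(ts[0], y0, theta))


def direct_loss(state0, theta):
    # Reference: ordinary backpropagation through the stored forward pass.
    state = state0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        state = step(t0, t1, state, theta)
    return state[0] ** 2


ct_state, ct_theta = reversible_grad(ts, state0, theta)
ref_state, ref_theta = jax.grad(direct_loss, argnums=(0, 1))(state0, theta)
```

Memory here stays constant in the number of steps, at the cost of one extra backward_step and one local forward step per step during backpropagation.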

ReversibleAdjoint now also pulls back gradients from any interpolated values, so we can use SaveAt(ts=...)!
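A sketch of how a saved value inside a step can be folded into the same local VJP. It assumes, for simplicity, linear interpolation between y0 and y1 as the dense output, the toy ODE dy/dt = -θy, and a reversible Heun step; the function names, the save time and the two loss terms (0.3·y1 standing in for later steps, y_save² for the saved value) are hypothetical. Because y_save depends only on the step's inputs and outputs, its cotangent can be passed to the same vjp call as the cotangent of the next state.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(t, y, theta):
    # Toy vector field (an assumption for illustration): dy/dt = -theta * y.
    return -theta * y


def step(t0, t1, state, theta):
    # Reversible Heun step; the carried state is (y, yhat, vf).
    y0, yhat0, vf0 = state
    dt = t1 - t0
    yhat1 = 2 * y0 - yhat0 + vf0 * dt
    vf1 = f(t1, yhat1, theta)
    y1 = y0 + 0.5 * (vf0 + vf1) * dt
    return (y1, yhat1, vf1)


def step_and_interp(t0, t1, t_save, state, theta):
    # Forward step plus the interpolated output at a save time inside [t0, t1].
    # Linear interpolation between y0 and y1 is assumed here.
    new_state = step(t0, t1, state, theta)
    w = (t_save - t0) / (t1 - t0)
    y_save = state[0] + w * (new_state[0] - state[0])
    return new_state, y_save


t0, t1, t_save = 0.0, 0.1, 0.04
theta = jnp.array(0.7)
y0 = jnp.array(1.0)
state0 = (y0, y0, f(t0, y0, theta))  # state as reconstructed at t0

# One VJP handles both cotangents: from later steps and from the saved value.
(new_state, y_save), vjp_fn = jax.vjp(
    lambda s, th: step_and_interp(t0, t1, t_save, s, th), state0, theta
)
ct_new_state = (
    jnp.full_like(new_state[0], 0.3),
    jnp.zeros_like(new_state[1]),
    jnp.zeros_like(new_state[2]),
)
ct_state0, ct_theta = vjp_fn((ct_new_state, 2 * y_save))


def local_loss(s, th):
    # Reference: the same local loss, differentiated directly.
    new_s, y_s = step_and_interp(t0, t1, t_save, s, th)
    return 0.3 * new_s[0] + y_s ** 2


ref_state0, ref_theta = jax.grad(local_loss, argnums=(0, 1))(state0, theta)
```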

We allow an arbitrary solver_state (provided it can be reconstructed reversibly) and calculate gradients w.r.t. solver_state. Finally, we pull these gradients back onto y0, args and terms using the solver.init method.
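A minimal JAX sketch of that final pullback, assuming a reversible-Heun-style `init` that sets solver_state = (y0, f(t0, y0, args)) and a hypothetical `rest_of_solve` standing in for everything after t0, with solver_state treated as an independent input (as the backward pass sees it). Only y0 and args are shown; terms carrying differentiable parameters would presumably be handled the same way.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(t, y, theta):
    # Toy vector field (an assumption): dy/dt = -theta * y.
    return -theta * y


def init(t0, y0, theta):
    # Assumed solver_state at t0: (yhat0, vf0) = (y0, f(t0, y0)).
    return (y0, f(t0, y0, theta))


def rest_of_solve(y0, solver_state, theta):
    # Hypothetical stand-in for the remainder of the solve and the loss.
    yhat0, vf0 = solver_state
    return jnp.sin(y0) * yhat0 + vf0 ** 2 + theta * y0


t0 = 0.0
y0 = jnp.array(1.3)
theta = jnp.array(0.7)
state0 = init(t0, y0, theta)

# Gradients as the reversible backward pass would leave them at t0.
ct_y0, ct_state, ct_theta = jax.grad(rest_of_solve, argnums=(0, 1, 2))(y0, state0, theta)

# Pull the solver_state gradient back onto y0 and args through `init`.
_, init_vjp = jax.vjp(lambda y, th: init(t0, y, th), y0, theta)
extra_y0, extra_theta = init_vjp(ct_state)
grad_y0 = ct_y0 + extra_y0
grad_theta = ct_theta + extra_theta

# Reference: differentiate the composite end to end.
full = lambda y, th: rest_of_solve(y, init(t0, y, th), th)
ref_y0, ref_theta = jax.grad(full, argnums=(0, 1))(y0, theta)
```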

sammccallum, Mar 14 '25 12:03

I've also added the reversible RK solvers here, which just subclass AbstractReversibleSolver. Let me know what you think of this and I can add some documentation when it's good to go!
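One way to make an arbitrary one-step method reversible is to couple two states (y, z); the sketch below wraps a classic RK4 increment this way. The coupling parameter `lam`, the helper names and this exact update rule are assumptions for illustration and may differ in detail from the reversible RK solvers in this PR. Each of the two updates can be undone in reverse order, with division by `lam` as the only non-trivial inversion.

```python
import math


def rk4_increment(f, t, y, h):
    # Classic RK4 increment: (one RK4 step of size h from y) minus y.
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def reversible_step(f, t0, t1, y, z, lam=0.99):
    # Update y from z with step +h, then update z from the new y with step -h.
    h = t1 - t0
    y_new = lam * y + (1 - lam) * z + rk4_increment(f, t0, z, h)
    z_new = z - rk4_increment(f, t1, y_new, -h)
    return y_new, z_new


def reversible_backward_step(f, t0, t1, y_new, z_new, lam=0.99):
    # Undo the z update first (y_new is known), then solve the y update for y.
    h = t1 - t0
    z = z_new + rk4_increment(f, t1, y_new, -h)
    y = (y_new - (1 - lam) * z - rk4_increment(f, t0, z, h)) / lam
    return y, z


f = lambda t, y: -y  # toy ODE dy/dt = -y with exact solution exp(-t)
ts = [0.05 * i for i in range(21)]
y, z = 1.0, 1.0
for t0, t1 in zip(ts[:-1], ts[1:]):
    y, z = reversible_step(f, t0, t1, y, z)
y_final = y
for t0, t1 in reversed(list(zip(ts[:-1], ts[1:]))):
    y, z = reversible_backward_step(f, t0, t1, y, z)
```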

sammccallum, Mar 14 '25 12:03

Heads-up that I've just updated the base branch to dev. It looks like there are a number of old commits sitting around on this PR, likely from where this branch forked off of main. You should be able to resolve these by first (a) squashing all the commits that actually belong on this branch together, and then (b) rebasing that new single commit on top of dev.

(Unrelatedly, lmk when this branch is ready for review.)

patrick-kidger, Jun 16 '25 22:06

I think I've now addressed all of your comments, so it should be ready for review 👍

Understanding how to rebase through multiple merges was an experience but I believe that is correct now...

sammccallum, Oct 12 '25 15:10

(No pressure to review anytime soon Patrick, I just marked this as "review requested" so it's easy for you to see across your sprawling jax empire)

sammccallum, Nov 19 '25 22:11