AbstractReversibleSolver + ReversibleAdjoint
Re-opening #593.
Implements an AbstractReversibleSolver base class and a ReversibleAdjoint for reversible backpropagation.
This updates SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun to subclass AbstractReversibleSolver.
Implementation
AbstractReversibleSolver subclasses AbstractSolver and adds a backward_step method:
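```python
@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, DenseInfo, _SolverState]:
    """Invert one step: from y1 and the solver state at t1, return y0,
    the dense info for [t0, t1], and the solver state at t0."""
    ...
```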
This method should reconstruct y0 and the solver state at t0 from y1 and the solver state at t1, i.e. it inverts a single forward step. See SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun for examples.
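As a concrete illustration of what a backward step has to do, here is a minimal sketch of a reversible Heun step for an ODE (no diffusion term), written as plain JAX functions rather than against the diffrax solver API. The names f, forward_step and backward_step are hypothetical; the pair (y, yhat) carried between steps plays the role of (y, solver_state).

```python
import jax.numpy as jnp


def f(t, y, args):
    # Hypothetical vector field: linear decay with rate `args`.
    return -args * y


def forward_step(t0, t1, y0, yhat0, args):
    # One reversible Heun step (ODE case): (y0, yhat0) -> (y1, yhat1).
    dt = t1 - t0
    f0 = f(t0, yhat0, args)
    yhat1 = 2 * y0 - yhat0 + f0 * dt
    y1 = y0 + 0.5 * (f0 + f(t1, yhat1, args)) * dt
    return y1, yhat1


def backward_step(t0, t1, y1, yhat1, args):
    # Algebraic inverse of forward_step: (y1, yhat1) -> (y0, yhat0).
    dt = t1 - t0
    f1 = f(t1, yhat1, args)
    yhat0 = 2 * y1 - yhat1 - f1 * dt
    y0 = y1 - 0.5 * (f(t0, yhat0, args) + f1) * dt
    return y0, yhat0
```

Because the inverse is exact in exact arithmetic, feeding the output of forward_step into backward_step recovers (y0, yhat0) up to floating-point round-off.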
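During backpropagation, ReversibleAdjoint calls backward_step to reconstruct the state at the start of each step, instead of storing it on the forward pass. We then take a VJP through a local forward step from that reconstructed state and accumulate the gradients.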
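The backward sweep described above can be sketched in plain JAX as follows. This is not the diffrax implementation: it uses a hypothetical ODE-only reversible Heun step, a hypothetical loss sum(y(T)), and the hypothetical helpers solve and reversible_grad. The initial auxiliary state yhat0 = y0 stands in for solver.init.

```python
import jax
import jax.numpy as jnp


def f(t, y, args):
    # Hypothetical vector field: linear decay with rate `args`.
    return -args * y


def forward_step(t0, t1, y0, yhat0, args):
    # One reversible Heun step (ODE case).
    dt = t1 - t0
    f0 = f(t0, yhat0, args)
    yhat1 = 2 * y0 - yhat0 + f0 * dt
    y1 = y0 + 0.5 * (f0 + f(t1, yhat1, args)) * dt
    return y1, yhat1


def backward_step(t0, t1, y1, yhat1, args):
    # Exact algebraic inverse of forward_step.
    dt = t1 - t0
    f1 = f(t1, yhat1, args)
    yhat0 = 2 * y1 - yhat1 - f1 * dt
    y0 = y1 - 0.5 * (f(t0, yhat0, args) + f1) * dt
    return y0, yhat0


def solve(y0, args, ts):
    # Forward pass; the "init" map sets yhat0 = y0.
    y, yhat = y0, y0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        y, yhat = forward_step(t0, t1, y, yhat, args)
    return y, yhat


def reversible_grad(y0, args, ts):
    # Gradient of loss = sum(y(T)) without storing the forward trajectory.
    y, yhat = solve(y0, args, ts)
    y_bar, yhat_bar = jnp.ones_like(y), jnp.zeros_like(yhat)
    args_bar = jnp.zeros_like(args)
    for t0, t1 in reversed(list(zip(ts[:-1], ts[1:]))):
        # 1. Reconstruct the state at t0 from the state at t1.
        y, yhat = backward_step(t0, t1, y, yhat, args)
        # 2. VJP through a local forward step from the reconstructed state.
        _, step_vjp = jax.vjp(
            lambda y_, yhat_, a_: forward_step(t0, t1, y_, yhat_, a_), y, yhat, args
        )
        y_bar, yhat_bar, a_bar = step_vjp((y_bar, yhat_bar))
        # 3. Accumulate the parameter gradient across steps.
        args_bar = args_bar + a_bar
    # Pull back through the init map yhat0 = y0: both cotangents reach y0.
    return y_bar + yhat_bar, args_bar
```

Since every reconstructed state matches the forward state up to round-off, the result should agree with differentiating straight through solve with jax.grad, while only ever holding one step's state in memory.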
ReversibleAdjoint now also pulls back gradients from any interpolated values, so we can use SaveAt(ts=...)!
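A value saved at some t inside a step [t0, t1] depends on that step's endpoints through the solver's interpolation, so during the backward sweep its cotangent has to be pulled back onto them. The sketch below assumes, purely for illustration, linear interpolation between y0 and y1; diffrax solvers define their own dense output, and linear_interp and pull_back_saved_value are hypothetical names.

```python
import jax
import jax.numpy as jnp


def linear_interp(t0, t1, y0, y1, t):
    # Hypothetical dense output: linear interpolation between step endpoints.
    theta = (t - t0) / (t1 - t0)
    return y0 + theta * (y1 - y0)


def pull_back_saved_value(t0, t1, y0, y1, t, saved_bar):
    # VJP of the interpolant: splits the cotangent of y(t) between y0 and y1.
    # These contributions would be added to the cotangents carried for this step.
    _, interp_vjp = jax.vjp(lambda a, b: linear_interp(t0, t1, a, b, t), y0, y1)
    return interp_vjp(saved_bar)
```

For t a quarter of the way through the step, a unit cotangent on y(t) splits as 0.75 onto y0 and 0.25 onto y1.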
We allow arbitrary solver_state (provided it can be reconstructed reversibly) and calculate gradients w.r.t. solver_state. Finally, we pull these gradients back onto y0, args and terms using the solver.init method.
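The final pull-back through init can be sketched with a single VJP. Here the solver state is, hypothetically, a cached vector-field evaluation f(t0, y0, args) (FSAL-style); vector_field, init and pull_back_through_init are illustrative names, and terms are simply closed over rather than differentiated.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Hypothetical vector field.
    return -args * y**2


def init(t0, y0, args):
    # Hypothetical solver init: the solver state caches f(t0, y0).
    return vector_field(t0, y0, args)


def pull_back_through_init(t0, y0, args, y0_bar, state0_bar, args_bar_steps):
    # y0_bar, state0_bar: cotangents for y0 and solver_state at t0 from the
    # backward sweep; args_bar_steps: parameter gradient accumulated over steps.
    _, init_vjp = jax.vjp(lambda y, a: init(t0, y, a), y0, args)
    y_via_state, args_via_state = init_vjp(state0_bar)
    # y0 receives gradient directly and through solver_state = init(t0, y0, args).
    return y0_bar + y_via_state, args_bar_steps + args_via_state
```

With y0 = 1.5, args = 0.7 and cotangents 3 (on y0) and 2 (on the state), this gives -1.2 for y0 and -4.5 for args, matching jax.grad of 3 * y0 + 2 * init(0, y0, args).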
I've also added the Reversible RK solvers here, which simply subclass AbstractReversibleSolver. Let me know what you think of this and I can add some documentation when it's good to go!
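One known way to make an arbitrary explicit RK step reversible is to couple it with an auxiliary state z, as sketched below with an explicit midpoint step. It is an assumption that this resembles the Reversible RK solvers in this PR; f, midpoint_increment, LAM, forward_step and backward_step are all hypothetical. Here z plays the role of solver_state, and backward_step is an exact algebraic inverse of forward_step.

```python
import jax.numpy as jnp


def f(t, y):
    # Hypothetical vector field.
    return jnp.sin(t) - y


def midpoint_increment(t, y, h):
    # Increment of one explicit midpoint (RK2) step of size h: y_next - y.
    k1 = f(t, y)
    k2 = f(t + 0.5 * h, y + 0.5 * h * k1)
    return h * k2


LAM = 0.99  # Hypothetical coupling parameter in (0, 1].


def forward_step(t0, t1, y0, z0):
    # (y0, z0) -> (y1, z1); z1 is updated with a backwards-in-time step from y1.
    h = t1 - t0
    y1 = LAM * y0 + (1 - LAM) * z0 + midpoint_increment(t0, z0, h)
    z1 = z0 - midpoint_increment(t1, y1, -h)
    return y1, z1


def backward_step(t0, t1, y1, z1):
    # Undo the two updates in reverse order: first recover z0, then y0.
    h = t1 - t0
    z0 = z1 + midpoint_increment(t1, y1, -h)
    y0 = (y1 - (1 - LAM) * z0 - midpoint_increment(t0, z0, h)) / LAM
    return y0, z0
```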
Heads-up that I've just updated the base branch to dev. It looks like there are a number of old commits sitting around on this PR, likely from where this branch forked off of main. You should be able to resolve these by first (a) squashing all the commits that actually belong on this branch together, and then (b) rebasing that new single commit on top of dev.
(Unrelatedly, lmk when this branch is ready for review.)
I think I've now addressed all of your comments, so it should be ready for review 👍
Understanding how to rebase through multiple merges was an experience, but I believe it's correct now...
(No pressure to review anytime soon Patrick, I just marked this as "review requested" so it's easy for you to see across your sprawling jax empire)