Question about different adjoint methods
Just wanted to ask if I have the right understanding of what algorithms are being used for each adjoint method, as the documentation does not link each method to a paper or reference.
My understanding is:

- RecursiveCheckpointAdjoint - Autodiff through the ODE solver, uses checkpointing
- DirectAdjoint - Autodiff through the ODE solver, doesn't use checkpointing
- BacksolveAdjoint - Continuous-time adjoint equation solver, doesn't seem to use checkpointing
- ForwardMode - Some sort of forward sensitivity method, either in continuous or discrete time (unclear)
- ImplicitAdjoint - Optimize the trajectory y (written as a discrete set of points) to set the ODE RHS to zero
Is this correct? The one I'm most confused about is ForwardMode.
So RecursiveCheckpointAdjoint does already have some fairly extensive BibTeX citations in its documentation. These are also displayed via diffrax.citation(..., adjoint=RecursiveCheckpointAdjoint()). In particular it uses Stumm--Walther--Wang--Moin--Iaccarino recursive checkpointing over unbounded horizons.
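For reference, those citations can be pulled up programmatically. A minimal sketch (the toy vector field here is made up purely for illustration; diffrax.citation accepts the same arguments as diffeqsolve):

```python
import diffrax

# Print BibTeX entries for the numerical methods selected by these arguments,
# including the checkpointing scheme used by RecursiveCheckpointAdjoint.
term = diffrax.ODETerm(lambda t, y, args: -y)  # hypothetical toy ODE
diffrax.citation(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=1.0,
    adjoint=diffrax.RecursiveCheckpointAdjoint(),
)
```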
DirectAdjoint does use checkpointing, but it's an alternate approach involving multiple nested scan-checkpoints. This is a JAX hack that I cooked up for a few compatibility cases, nothing more.
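To illustrate the general trick (this is just the pattern in miniature, not Diffrax's actual implementation): each level of a nested scan is wrapped in jax.checkpoint, so reverse mode stores one carry per block and recomputes the steps inside each block on the backward pass.

```python
import jax

def step(y, _):
    # One explicit Euler step of dy/dt = -y (toy dynamics for illustration).
    return y - 0.01 * y, None

@jax.checkpoint
def block(y, _):
    # Inner scan of 16 steps. jax.checkpoint means reverse mode saves only
    # the carry entering this block, recomputing the 16 steps when needed.
    return jax.lax.scan(step, y, None, length=16)

def run(y0):
    # Outer scan over 16 checkpointed blocks: 256 steps in total, but reverse
    # mode materialises O(16) intermediate states rather than O(256).
    y, _ = jax.lax.scan(block, y0, None, length=16)
    return y

print(jax.grad(run)(1.0))
```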
BacksolveAdjoint is indeed optimise-then-discretise reverse mode. Still a terrible method that got popular for no good reason. We could probably include some extra references for this if we wanted, but they're pretty numerous.
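For completeness, the equations in question are the standard continuous adjoint equations, sketched here for dy/dt = f(t, y, θ) (sign conventions vary between references):

```latex
% a(t) := \partial L / \partial y(t) is integrated backwards from t_1 to t_0:
\frac{\mathrm{d}a}{\mathrm{d}t} = -a(t)^\top \frac{\partial f}{\partial y}\bigl(t, y(t), \theta\bigr),
\qquad
\frac{\mathrm{d}L}{\mathrm{d}\theta} = \int_{t_0}^{t_1} a(t)^\top \frac{\partial f}{\partial \theta}\bigl(t, y(t), \theta\bigr)\,\mathrm{d}t.
```

Numerically, y(t) is recomputed backwards in time alongside a(t), which is the usual source of this method's inaccuracy.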
ForwardMode is discretise-then-optimise forward mode. Good catch that it's not documented which; I've just pushed a doc update to include that.
ImplicitAdjoint requires that you solve the ODE to a steady state. It then computes gradients with respect to your inputs (y0, args, ...) just like the other methods, whilst exploiting the fact that the result is a steady state to use a computationally cheaper adjoint method, namely the implicit function theorem. (So in particular it's not about optimising the trajectory to get the RHS to zero; rather, it requires that you already be solving your ODE to a steady state.)
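A minimal sketch of what that looks like (assuming a recent Diffrax version, in which a steady state is detected via diffrax.Event(diffrax.steady_state_event()); older versions spelt the event differently, and the toy ODE here is made up for illustration):

```python
import diffrax
import jax
import jax.numpy as jnp

def vector_field(t, y, args):
    # dy/dt = a - y, with steady state y* = a.
    return args - y

term = diffrax.ODETerm(vector_field)

def steady_state(a):
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=jnp.inf,
        dt0=0.1,
        y0=0.0,
        args=a,
        max_steps=None,
        # Terminate once the vector field is approximately zero...
        event=diffrax.Event(diffrax.steady_state_event()),
        # ...and differentiate via the implicit function theorem at y*,
        # rather than backpropagating through the solver's steps.
        adjoint=diffrax.ImplicitAdjoint(),
    )
    return sol.ys[-1]

# Analytically dy*/da = 1.
print(jax.grad(steady_state)(2.0))
```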
I hope that helps!
Thanks Patrick this clarifies it!
So just to double-check: ForwardMode is very similar to RecursiveCheckpointAdjoint and DirectAdjoint (all of which are discretize-then-optimize approaches), but it uses forward-mode AD whereas the other two use reverse-mode AD. So, for example, I would not want to train a neural ODE with ForwardMode, but for a small parameterized ODE it could be a good option. Does ForwardMode also do some form of checkpointing?
Yes exactly.
There's no checkpointing in ForwardMode. Indeed, I usually think of checkpointing as a concept associated only with reverse-mode autodiff.
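For concreteness, a minimal sketch of the difference (the toy ODE and numbers are made up for illustration): reverse mode goes through jax.grad, while ForwardMode is driven by jax.jacfwd (or jax.jvp).

```python
import diffrax
import jax

def vector_field(t, y, args):
    # dy/dt = -a * y, with a single scalar parameter a.
    return -args * y

term = diffrax.ODETerm(vector_field)

def solve(a, adjoint):
    sol = diffrax.diffeqsolve(
        term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=1.0,
        args=a, adjoint=adjoint,
    )
    return sol.ys[-1]

# Reverse mode: differentiate back through the solver's steps, with
# recursive checkpointing keeping the memory cost down.
grad_rev = jax.grad(solve)(2.0, diffrax.RecursiveCheckpointAdjoint())

# Forward mode: propagate tangents alongside the solve. Note jax.jacfwd
# rather than jax.grad, which is reverse mode.
grad_fwd = jax.jacfwd(solve)(2.0, diffrax.ForwardMode())

# Both should agree with the analytic answer d/da exp(-a) = -exp(-a).
print(grad_rev, grad_fwd)
```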