SciMLSensitivity.jl

Adding A New AD

Open willtebbutt opened this issue 1 year ago • 4 comments

Question❓

I'm interested in getting (the soon-to-be-renamed) Tapir.jl to work nicely with SciMLSensitivity.jl, as I understand this to be the way to get it to play nicely with the SciML ecosystem more broadly (please correct me if I'm wrong on this point!).

I can't see from the docs, or from a quick dive into some of the internals of this package, what the right way to go about this is. It looks like there are a few functions I need to add methods for, but I'm not at all sure.

Could someone point me in the right direction?

willtebbutt, Sep 05 '24 15:09

You'd just define a _vecjacobian! overload: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/derivative_wrappers.jl#L656-L752

along with the corresponding caches: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/adjoint_common.jl#L212-L215

That cache setup is where you build whatever buffers are required so that the vecjacobian (vector-Jacobian product) computation does not allocate.
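To make the cache idea concrete: an adjoint solve repeatedly needs the products λᵀ(∂f/∂u) and λᵀ(∂f/∂p) for the ODE right-hand side f, and the point of the cache is that these are written into buffers allocated once, up front. The sketch below illustrates that pattern in plain Python only; it is not SciMLSensitivity's API. The class name LotkaVolterraVJPCache, its vjp method, and the hand-derived Lotka-Volterra derivatives are all hypothetical stand-ins for what an AD backend such as Tapir would compute inside a _vecjacobian! overload.

```python
class LotkaVolterraVJPCache:
    """Hypothetical cache: output buffers are allocated once, then overwritten on every call."""

    def __init__(self, n_states, n_params):
        self.du_bar = [0.0] * n_states  # receives lambda^T (df/du)
        self.dp_bar = [0.0] * n_params  # receives lambda^T (df/dp)

    def vjp(self, lam, u, p, t):
        """Write the VJPs of the Lotka-Volterra right-hand side
            f1 = p1*u1 - p2*u1*u2
            f2 = -p3*u2 + p4*u1*u2
        into the cached buffers. The derivatives are hand-derived here; an AD
        backend would produce them automatically. t is unused (autonomous system)."""
        u1, u2 = u
        p1, p2, p3, p4 = p
        l1, l2 = lam
        # lambda^T (df/du): contract lam with each column of the 2x2 state Jacobian
        self.du_bar[0] = l1 * (p1 - p2 * u2) + l2 * p4 * u2
        self.du_bar[1] = -l1 * p2 * u1 + l2 * (p4 * u1 - p3)
        # lambda^T (df/dp): contract lam with each column of the 2x4 parameter Jacobian
        self.dp_bar[0] = l1 * u1
        self.dp_bar[1] = -l1 * u1 * u2
        self.dp_bar[2] = -l2 * u2
        self.dp_bar[3] = l2 * u1 * u2
        # The same list objects are returned every time; no new output lists are created.
        return self.du_bar, self.dp_bar


cache = LotkaVolterraVJPCache(2, 4)
du_bar, dp_bar = cache.vjp([1.0, 0.5], [1.0, 1.0], [1.5, 1.0, 3.0, 1.0], 0.0)
print(du_bar, dp_bar)  # [1.0, -2.0] [1.0, -1.0, -0.5, 0.5]
```

The design choice being illustrated is the split between a one-time setup step (building the cache) and a hot-path call that only overwrites existing storage, which is what makes the per-step VJP non-allocating.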

You would probably also want to opt it into callback handling, since Tapir supports mutation: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/callback_tracking.jl

Similarly, GaussAdjoint needs a special cache, https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/gauss_adjoint.jl#L416-L432, and an overload, https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/gauss_adjoint.jl#L496-L504

and QuadratureAdjoint needs a very similar pair: https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/quadrature_adjoint.jl#L213-L239 and https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/quadrature_adjoint.jl#L293-L300

ChrisRackauckas, Sep 08 '24 01:09

Lovely, thanks for this. Is there a recommended way to test that my overloads are implemented correctly?

willtebbutt, Sep 11 '24 12:09

You'll see that this file runs through lots of combinations:

https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L46-L196

If you add a VJP option for Tapir, you should be able to just pass autojacvec = AutoTapir() and add a test checking that the result matches the existing ones. There are about five sets of these tests.

Because this grew organically over time, it's a bit verbose, but we'll turn it into a loop some day 😅
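The testing approach described above boils down to computing the gradient pieces with the new backend and checking them against a trusted reference, ideally in a loop over cases rather than in copy-pasted blocks. The following is a minimal Python sketch of that pattern, not the actual test/adjoint.jl code: the Lotka-Volterra problem, the function names (vjp_candidate, vjp_reference, check_all), the central finite-difference reference, and the tolerances are all assumptions chosen for illustration.

```python
import math


def lotka_volterra(u, p):
    """Hypothetical test problem: the Lotka-Volterra right-hand side f(u, p)."""
    u1, u2 = u
    p1, p2, p3, p4 = p
    return [p1 * u1 - p2 * u1 * u2, -p3 * u2 + p4 * u1 * u2]


def vjp_candidate(lam, u, p):
    """Stand-in for the new backend: returns (lambda^T df/du, lambda^T df/dp)."""
    u1, u2 = u
    p1, p2, p3, p4 = p
    l1, l2 = lam
    du = [l1 * (p1 - p2 * u2) + l2 * p4 * u2, -l1 * p2 * u1 + l2 * (p4 * u1 - p3)]
    dp = [l1 * u1, -l1 * u1 * u2, -l2 * u2, l2 * u1 * u2]
    return du, dp


def vjp_reference(lam, u, p, h=1e-6):
    """Trusted reference: central finite differences of the scalar lambda^T f(u, p)."""

    def scalar(uu, pp):
        return sum(l * fi for l, fi in zip(lam, lotka_volterra(uu, pp)))

    def fd_grad(x, g):
        out = []
        for i in range(len(x)):
            xp, xm = list(x), list(x)
            xp[i] += h
            xm[i] -= h
            out.append((g(xp) - g(xm)) / (2 * h))
        return out

    return fd_grad(u, lambda uu: scalar(uu, p)), fd_grad(p, lambda pp: scalar(u, pp))


def check_all(vjp, cases, rtol=1e-6, atol=1e-7):
    """Loop over (lam, u, p) cases instead of writing one test block per combination."""
    for lam, u, p in cases:
        got, want = vjp(lam, u, p), vjp_reference(lam, u, p)
        for got_vec, want_vec in zip(got, want):
            for g, w in zip(got_vec, want_vec):
                if not math.isclose(g, w, rel_tol=rtol, abs_tol=atol):
                    return False
    return True


cases = [
    ([1.0, 0.5], [1.0, 1.0], [1.5, 1.0, 3.0, 1.0]),
    ([0.3, -2.0], [2.0, 0.7], [1.5, 1.0, 3.0, 1.0]),
]
print(check_all(vjp_candidate, cases))  # True
```

In SciMLSensitivity itself, the analogous check is the one described in the comment above: run the existing adjoint tests with autojacvec = AutoTapir() and compare against the results already computed there. The part that carries over from this sketch is the loop over cases with a single comparison routine.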

ChrisRackauckas, Sep 11 '24 13:09

Haha excellent -- thanks for the info!

willtebbutt, Sep 11 '24 13:09