SciMLSensitivity.jl
Adding A New AD
Question❓
I'm interested in getting (the soon-to-be-renamed) Tapir.jl to work nicely with SciMLSensitivity.jl, since I understand this is the way to make it work with the SciML ecosystem more broadly (please correct me if I'm wrong on this point!).
I can't see from the docs, or from a quick dive into some of the internals of this package, what the right way to go about this is. It looks like there are a few functions that I need to add methods for, but I'm not at all sure.
Could someone point me in the right direction?
You'd just define a _vecjacobian! overload: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/derivative_wrappers.jl#L656-L752
along with the corresponding cache setup: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/adjoint_common.jl#L212-L215
which is where you build whatever caches are needed for a non-allocating _vecjacobian!.
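As a rough illustration of what such an overload has to produce (this is not the actual _vecjacobian! signature, which is in the linked file): for an in-place right-hand side f!(du, u, p, t), the adjoint needs λ' * ∂f/∂u and λ' * ∂f/∂p written into preallocated arrays. The sketch below hand-codes that vector-Jacobian product for a Lotka-Volterra model. `lotka!` and `lotka_vjp!` are hypothetical names. An AD backend such as Tapir would generate the equivalent of `lotka_vjp!`, with its cache holding whatever it needs to avoid allocating on every call.

```julia
# Hypothetical example model (not from SciMLSensitivity.jl): in-place
# Lotka-Volterra right-hand side in the usual f!(du, u, p, t) form.
function lotka!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

# Hypothetical hand-written VJP standing in for what an AD backend computes.
# Writes dλ = (∂f/∂u)' * λ and dgrad = (∂f/∂p)' * λ into preallocated
# outputs, so repeated calls allocate nothing.
function lotka_vjp!(dλ, dgrad, u, p, t, λ)
    dλ[1] = λ[1] * (p[1] - p[2] * u[2]) + λ[2] * p[4] * u[2]
    dλ[2] = -λ[1] * p[2] * u[1] + λ[2] * (-p[3] + p[4] * u[1])
    dgrad[1] = λ[1] * u[1]
    dgrad[2] = -λ[1] * u[1] * u[2]
    dgrad[3] = -λ[2] * u[2]
    dgrad[4] = λ[2] * u[1] * u[2]
    return nothing
end

u = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
λ = [0.5, -2.0]
dλ, dgrad = similar(u), similar(p)   # the "cache": allocated once, reused
lotka_vjp!(dλ, dgrad, u, p, 0.0, λ)
# dλ == [-1.75, 3.5], dgrad == [0.5, -0.5, 2.0, -2.0]
```

The dgrad part (λ' * ∂f/∂p) is the piece the parameter-gradient integrands in adjoints such as QuadratureAdjoint and GaussAdjoint rely on. That is a plausible reason those methods need their own cache and overload as well.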
You'd probably also want to opt it into callback tracking (https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/callback_tracking.jl), since it supports mutation.
Similarly, GaussAdjoint needs a special cache https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/gauss_adjoint.jl#L416-L432 and an overload https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/gauss_adjoint.jl#L496-L504.
QuadratureAdjoint needs a fairly similar cache and overload: https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/quadrature_adjoint.jl#L213-L239 and https://github.com/SciML/SciMLSensitivity.jl/blob/eded161bf72dc6815f64bac6756d5afdd78881d5/src/quadrature_adjoint.jl#L293-L300.
Lovely, thanks for this. Is there a recommended way to test that my overloads are implemented correctly?
You'll see that this file runs through lots of combinations:
https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L46-L196
If you add a VJP option for Tapir, you should be able to add autojacvec = AutoTapir() to those combinations, along with a test that its result matches the others. There are about 5 such sets of tests to extend.
Because this grew organically over time it's a bit verbose, but we'll make that into a loop some day 😅
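While developing the overload, a cheap extra check is to compare the new VJP against a backend-independent reference before running the full solve-level tests. The sketch below is plain Julia using only the stdlib. It uses central finite differences as the reference and loops over random points and over cases, including a deliberately buggy VJP to show the check catches a missing transpose. `toy!`, `toy_vjp!`, `wrong_vjp!`, `fd_vjp`, and `max_vjp_error` are hypothetical helpers, not part of SciMLSensitivity.jl.

```julia
using LinearAlgebra, Random

# Hypothetical toy model and its hand-derived VJP: dλ = J' λ, dgrad = (∂f/∂p)' λ.
function toy!(du, u, p, t)
    du[1] = -p[1] * u[1] + u[2]^2
    du[2] = p[2] * u[1] * u[2] - sin(t)
    return nothing
end

function toy_vjp!(dλ, dgrad, u, p, t, λ)
    dλ[1] = -p[1] * λ[1] + p[2] * u[2] * λ[2]
    dλ[2] = 2 * u[2] * λ[1] + p[2] * u[1] * λ[2]
    dgrad[1] = -u[1] * λ[1]
    dgrad[2] = u[1] * u[2] * λ[2]
    return nothing
end

# Deliberately buggy version: computes J * λ instead of J' * λ.
function wrong_vjp!(dλ, dgrad, u, p, t, λ)
    dλ[1] = -p[1] * λ[1] + 2 * u[2] * λ[2]
    dλ[2] = p[2] * u[2] * λ[1] + p[2] * u[1] * λ[2]
    dgrad[1] = -u[1] * λ[1]
    dgrad[2] = u[1] * u[2] * λ[2]
    return nothing
end

# Reference VJP via central differences, one column of ∂f/∂u and ∂f/∂p at a time.
function fd_vjp(f!, u, p, t, λ; h = 1e-6)
    n, m = length(u), length(p)
    fp, fm = similar(u), similar(u)
    dλ, dgrad = zeros(n), zeros(m)
    for j in 1:n
        up = copy(u); up[j] += h
        um = copy(u); um[j] -= h
        f!(fp, up, p, t); f!(fm, um, p, t)
        dλ[j] = dot(λ, (fp .- fm) ./ (2h))
    end
    for j in 1:m
        pp = copy(p); pp[j] += h
        pm = copy(p); pm[j] -= h
        f!(fp, u, pp, t); f!(fm, u, pm, t)
        dgrad[j] = dot(λ, (fp .- fm) ./ (2h))
    end
    return dλ, dgrad
end

# Largest absolute deviation from the reference over a few seeded random points.
function max_vjp_error(vjp!, f!, n, m; ntrials = 5, rng = MersenneTwister(1))
    worst = 0.0
    for _ in 1:ntrials
        u, p, λ = randn(rng, n), randn(rng, m), randn(rng, n)
        t = rand(rng)
        dλ, dgrad = zeros(n), zeros(m)
        vjp!(dλ, dgrad, u, p, t, λ)
        ref_dλ, ref_dgrad = fd_vjp(f!, u, p, t, λ)
        worst = max(worst, maximum(abs, dλ .- ref_dλ), maximum(abs, dgrad .- ref_dgrad))
    end
    return worst
end

err_ok = max_vjp_error(toy_vjp!, toy!, 2, 2)
err_bad = max_vjp_error(wrong_vjp!, toy!, 2, 2)
for (name, err) in (("correct VJP", err_ok), ("untransposed Jacobian", err_bad))
    println(rpad(name, 24), err)
end
```

Replacing `toy_vjp!` with an AD-generated VJP turns this into a quick unit check. The end-to-end gradient comparisons in test/adjoint.jl remain the real acceptance test. Writing the cases as a loop like this is also one way to keep that kind of test list from growing verbose.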
Haha excellent -- thanks for the info!