Improve Nested AD
The nested AD used in DiffEqFlux is not ideal: it calls ForwardDiff.gradient/jacobian, and Zygote's overrides for those calls materialize the full Hessian and only then perform the HVP, instead of computing the HVP directly.
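To make the inefficiency concrete, here is a minimal sketch (the names `hvp_dense` and `hvp_forward_over_reverse` are mine, not an existing API): the first variant materializes the Hessian the way the current override effectively does, while the second computes the same HVP with a single forward-over-reverse pass.

```julia
using ForwardDiff, Zygote

# Roughly what happens today: build the full n×n Hessian, then contract
# it with v — roughly n times the cost of one gradient for a single HVP.
hvp_dense(f, u, v) = ForwardDiff.hessian(f, u) * v

# Forward-over-reverse: push a dual perturbation along v through one
# reverse-mode gradient; the Hessian is never materialized.
function hvp_forward_over_reverse(f, u, v)
    duals = ForwardDiff.Dual.(u, v)        # seed the dual parts with v
    g = Zygote.gradient(f, duals)[1]       # dual-valued gradient at u + εv
    return ForwardDiff.partials.(g, 1)     # ε coefficients, i.e. H(u) * v
end

f(u) = sum(abs2, u) + u[1] * u[2]
u, v = rand(3), rand(3)
hvp_dense(f, u, v) ≈ hvp_forward_over_reverse(f, u, v)  # true
```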
BatchedRoutines.jl has routines that do this efficiently; we should start migrating the code to use it.
Note that this cannot be upstreamed to Zygote, because it requires capturing a different gradient/jacobian call to compute $\frac{\partial}{\partial p}\left(\frac{\partial f}{\partial u}\right)^T v$. Capturing the ForwardDiff calls only lets us override $\frac{\partial}{\partial u}\left(\frac{\partial f}{\partial u}\right)^T v$.
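Concretely, the cross term looks like the following forward-over-reverse sketch (`mixed_vjp` is a hypothetical helper for illustration, assuming a scalar $f(u, p)$): the dual perturbation is seeded on $u$, but the outer gradient is taken with respect to $p$, so it cannot be expressed by overriding the ForwardDiff call on $u$ alone.

```julia
using ForwardDiff, Zygote, LinearAlgebra

# ∂/∂p [(∂f/∂u)ᵀ v]: seed u with duals along v, then take the reverse-mode
# gradient with respect to p. The ε coefficients of ∂f/∂p give the cross term.
function mixed_vjp(f, u, p, v)
    ud = ForwardDiff.Dual.(u, v)                 # perturb u along v
    gp = Zygote.gradient(p -> f(ud, p), p)[1]    # dual-valued ∂f/∂p at (u + εv, p)
    return ForwardDiff.partials.(gp, 1)          # (∂²f/∂p∂u) ⋅ v
end

f(u, p) = dot(u, p) + sum(abs2, u) * sum(abs2, p)
u, p, v = rand(3), rand(3), rand(3)
# Dense check: Jacobian of the u-gradient with respect to p, transposed, times v.
mixed_vjp(f, u, p, v) ≈
    ForwardDiff.jacobian(p -> ForwardDiff.gradient(u -> f(u, p), u), p)' * v  # true
```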
(I will probably do the migration myself over the summer if no one else picks it up.)
This seems like a good GSoC project.
Update on this: https://github.com/LuxDL/Lux.jl/pull/598 will handle everything automatically, so relying on another package is unnecessary. We just need to change the Zygote.gradient(f, x, ps) calls to Zygote.gradient(::StatefulLuxLayer, x).
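For reference, a sketch of what the migration looks like (the model and loss here are invented for illustration, and the StatefulLuxLayer constructor details vary across Lux versions):

```julia
using Lux, Random, Zygote

model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 8)

# Before: the inner gradient threads ps explicitly — a generic
# Zygote.gradient(f, x, ps) call that the nested-AD machinery cannot rewrite.
function loss_before(ps, x)
    f(x, ps) = sum(abs2, first(model(x, ps, st)))
    gx, _ = Zygote.gradient(f, x, ps)
    return sum(abs2, gx)
end

# After: wrap (model, ps, st) in a StatefulLuxLayer so the inner call is a
# Zygote.gradient on a StatefulLuxLayer-based function; Lux then switches
# the nested AD to forward-over-reverse automatically.
function loss_after(ps, x)
    smodel = StatefulLuxLayer(model, ps, st)
    gx = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x))
    return sum(abs2, gx)
end

Zygote.gradient(ps -> loss_after(ps, x), ps)  # outer pass over the parameters
```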