
Improve Nested AD

Open avik-pal opened this issue 1 year ago • 2 comments

The nested AD used in DiffEqFlux is not ideal: it calls ForwardDiff.gradient/ForwardDiff.jacobian, and Zygote's overrides for those calls materialize the full Hessian before performing the HVP.
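To make the cost difference concrete, here is a hedged sketch (toy loss `f`, names illustrative; not DiffEqFlux's actual code path) contrasting the Hessian-materializing pattern with a direct forward-over-reverse HVP:

```julia
using ForwardDiff, Zygote

f(u) = sum(abs2, u) / 2  # toy scalar loss

# Inefficient pattern: differentiating through ForwardDiff.gradient builds
# the full n×n Hessian, then contracts it with v.
hvp_slow(u, v) = Zygote.jacobian(x -> ForwardDiff.gradient(f, x), u)[1] * v

# Forward-over-reverse HVP: a single directional (dual-number) derivative of
# the reverse-mode gradient along v — O(1) gradient evaluations, no Hessian.
hvp_fast(u, v) =
    ForwardDiff.derivative(t -> Zygote.gradient(f, u .+ t .* v)[1], 0.0)
```

Both return the same vector H(u) * v; only the second avoids the O(n²) intermediate.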

BatchedRoutines.jl has routines to do this efficiently; we should start migrating the code to use them.

Note that this cannot be upstreamed to Zygote, because it requires capturing a different gradient / jacobian call to compute $\frac{\partial}{\partial p}\left(\frac{\partial f}{\partial u}\right)^T v$. Capturing the ForwardDiff calls only allows us to override $\frac{\partial}{\partial u}\left(\frac{\partial f}{\partial u}\right)^T v$.
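Since mixed partials commute, the quantity above can be computed as $\frac{\partial}{\partial t}\left.\nabla_p f(u + tv, p)\right|_{t=0}$, i.e. an outer forward-mode derivative over an inner reverse-mode gradient with respect to $p$ — a different nesting than the $u$-only case Zygote's overrides can see. A hedged sketch of that ordering (toy `f`, illustrative names):

```julia
using ForwardDiff, Zygote

f(u, p) = sum(abs2, u .* p)  # toy scalar loss in u and p

# ∂/∂p [(∂f/∂u)ᵀ v]  ==  ∂/∂t ∇ₚ f(u + t v, p) |_{t=0}
# Outer: forward-mode dual perturbation of u along v.
# Inner: reverse-mode gradient w.r.t. p (not u), which is exactly the call
# that capturing ForwardDiff.gradient over u cannot override.
dvjp_p(u, p, v) = ForwardDiff.derivative(
    t -> Zygote.gradient(q -> f(u .+ t .* v, q), p)[1], 0.0)
```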

(I will probably do the migration myself over the summer if no one else picks it up.)

avik-pal avatar Mar 16 '24 16:03 avik-pal

This seems like a good GSoC project.

ChrisRackauckas avatar Mar 18 '24 02:03 ChrisRackauckas

Update on this: https://github.com/LuxDL/Lux.jl/pull/598 will handle everything automatically, so relying on another package is unnecessary. We just need to change the Zygote.gradient(f, x, ps) calls to Zygote.gradient(::StatefulLuxLayer, x).
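A hedged sketch of what that call-site change looks like (constructor signature per recent Lux.jl versions; check current docs before copying):

```julia
using Lux, Random, Zygote

model = Dense(2 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 4)

# Before: an explicit closure over (model, ps, st)
loss_old(x, ps) = sum(first(model(x, ps, st)))
# Zygote.gradient(loss_old, x, ps)

# After: wrap into a StatefulLuxLayer so Lux can intercept nested AD calls.
# The `{true}` type parameter fixes the state `st` (per recent Lux API).
smodel = StatefulLuxLayer{true}(model, ps, st)
Zygote.gradient(x -> sum(smodel(x)), x)
```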

avik-pal avatar Apr 22 '24 21:04 avik-pal