fmmax
Use "differentiable optimization trick" for backpropagation through tangent vector field calculation
Currently, we directly backpropagate through the tangent vector field calculation, which involves a Newton solve to find the minimum of a convex quadratic objective. It may be more efficient to define a custom gradient for this operation, in a manner similar to what is done for differentiable optimization.
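For reference, the pattern would look something like the minimal JAX sketch below. The `objective`, `newton_solve`, and `solve` names and the toy quadratic are placeholders rather than fmmax's actual tangent-field code; the point is that the custom VJP differentiates through the optimality condition instead of unrolling the Newton iterations.

```python
import jax
import jax.numpy as jnp

# Placeholder convex objective in x, parameterized by theta. The real objective
# used for the tangent vector field is different; this only illustrates the
# custom-gradient pattern.
def objective(x, theta):
    return 0.5 * jnp.sum(x**2) + jnp.sum(jnp.sin(theta) * x)

def newton_solve(theta, num_steps=20):
    """Newton iterations; reverse-mode AD would otherwise unroll all of these."""
    def step(x, _):
        g = jax.grad(objective)(x, theta)
        h = jax.hessian(objective)(x, theta)
        return x - jnp.linalg.solve(h, g), None
    x, _ = jax.lax.scan(step, jnp.zeros_like(theta), xs=None, length=num_steps)
    return x

@jax.custom_vjp
def solve(theta):
    return newton_solve(theta)

def solve_fwd(theta):
    x_star = newton_solve(theta)
    return x_star, (theta, x_star)

def solve_bwd(res, x_star_bar):
    # Implicit function theorem: F(x, theta) = grad_x objective(x, theta) = 0
    # at the solution, so dx*/dtheta = -(dF/dx)^{-1} dF/dtheta and
    # theta_bar = -(dF/dtheta)^T (dF/dx)^{-T} x_star_bar.
    theta, x_star = res
    dfdx = jax.hessian(objective)(x_star, theta)   # dF/dx at the solution
    w = jnp.linalg.solve(dfdx.T, x_star_bar)       # (dF/dx)^{-T} x_star_bar
    _, vjp_theta = jax.vjp(lambda t: jax.grad(objective)(x_star, t), theta)
    (theta_bar,) = vjp_theta(-w)
    return (theta_bar,)

solve.defvjp(solve_fwd, solve_bwd)

# Gradients now flow through the optimality condition, not the Newton iterations.
grads = jax.grad(lambda t: jnp.sum(solve(t)))(jnp.arange(3.0))
```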
I am seeing some issues with super-long compile times in the optimization context, which are eliminated when we use a `stop_gradient` before the vector field calculation. I am thinking we should just add this `stop_gradient`, and then restore the ability to backpropagate through vector field generation via the method mentioned above. This might be fairly involved, and would take time. fyi @smartalecH
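Concretely, the workaround amounts to something like the sketch below, where `compute_tangent_field` and `density` are hypothetical stand-ins rather than actual fmmax names:

```python
import jax
import jax.numpy as jnp

def compute_tangent_field(density):
    # Hypothetical stand-in for the real tangent vector field calculation,
    # which internally performs the Newton solve.
    dx, dy = jnp.gradient(density)
    return jnp.stack([dy, -dx])  # tangent vectors: density gradient rotated 90 degrees

def tangent_field_without_gradient(density):
    # stop_gradient keeps reverse-mode AD from tracing back through the vector
    # field calculation, which is what drives the long compile times; the cost
    # is that the field's own contribution to the gradient is dropped.
    return compute_tangent_field(jax.lax.stop_gradient(density))
```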
Yep this sounds like a good plan to me. How hard do we anticipate the manual adjoint will be?
I am looking at it a bit. It might actually be relatively straightforward. Here's a reference that seems nice; it even includes JAX code: https://implicit-layers-tutorial.org/implicit_functions/
@smartalecH @Luochenghuang I have things working here; all it needed was a bit of regularization.
https://github.com/mfschubert/mewtax
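The thread doesn't spell out what form the regularization takes. One common choice for a Newton solve like this is Tikhonov-style damping of the Hessian, sketched below with a toy objective; this is an assumption about the kind of fix meant here, not a description of what mewtax actually does.

```python
import jax
import jax.numpy as jnp

def damped_newton_step(fn, x, eps=1e-6):
    # Tikhonov-style damping (one guess at "a bit of regularization"): adding
    # eps * I keeps the Hessian solve well conditioned even when the
    # objective's curvature is nearly singular.
    g = jax.grad(fn)(x)
    h = jax.hessian(fn)(x)
    return x - jnp.linalg.solve(h + eps * jnp.eye(h.shape[-1]), g)

# Example with a deliberately ill-conditioned quadratic.
quadratic = lambda x: 0.5 * x @ jnp.diag(jnp.array([1.0, 1e-12])) @ x - x[0]
x_next = damped_newton_step(quadratic, jnp.array([3.0, 4.0]))
```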
I think we may want to put this on hold for now: the potential accuracy improvement is small, and there is a speed penalty.
- I added a test with #94 which checks the FD gradient against the AD gradient. They are very close as-is, i.e. even with the `stop_gradient` in the vector field calculation (a generic sketch of such a check is given below).
- I tested using `mewtax` to solve for the vector fields, but this seems to make the tests much slower (2x time to complete all tests). I suspect there is a significant compile-time penalty.
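The test from #94 is not reproduced in this thread; the sketch below only illustrates the general shape of an FD-versus-AD gradient check, with `loss_fn` and the helper names as hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

def finite_difference_gradient(loss_fn, params, delta=1e-4):
    """Central finite differences, one parameter at a time."""
    flat = params.ravel()
    grad = jnp.zeros_like(flat)
    for i in range(flat.size):
        step = jnp.zeros_like(flat).at[i].set(delta)
        grad = grad.at[i].set(
            (loss_fn((flat + step).reshape(params.shape))
             - loss_fn((flat - step).reshape(params.shape))) / (2 * delta)
        )
    return grad.reshape(params.shape)

def check_fd_vs_ad(loss_fn, params, rtol=1e-2):
    ad_grad = jax.grad(loss_fn)(params)
    fd_grad = finite_difference_gradient(loss_fn, params)
    assert jnp.allclose(ad_grad, fd_grad, rtol=rtol), (ad_grad, fd_grad)

# Example usage with a placeholder loss.
check_fd_vs_ad(lambda p: jnp.sum(jnp.sin(p) ** 2), jnp.array([0.3, 1.2, -0.7]))
```

JAX also ships `jax.test_util.check_grads`, which performs a similar numerical comparison and is another option for this kind of test.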