fmmax

Use "differentiable optimization trick" for backpropagation through tangent vector field calculation

Status: Open • mfschubert opened this issue 1 year ago • 5 comments

Currently, we directly backpropagate through the tangent vector field calculation, which involves a Newton solve to find the minimum of a convex quadratic objective. It may be more efficient to define a custom gradient for this operation, in a manner similar to what is done for differentiable optimization.

mfschubert, Jan 10 '24 17:01
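A minimal JAX sketch of the "differentiable optimization trick", assuming a generic convex quadratic objective rather than fmmax's actual tangent vector field code (objective, newton_solve and argmin are hypothetical names). The forward pass runs the Newton iterations as usual. The custom backward pass applies the implicit function theorem to the optimality condition grad_x f(x*, params) = 0, so it costs one adjoint linear solve with the Hessian instead of backpropagating through every Newton step.

```python
import jax
import jax.numpy as jnp


def objective(x, params):
    """Hypothetical convex quadratic: 0.5 * x^T A x - b^T x."""
    a, b = params
    return 0.5 * x @ a @ x - b @ x


def newton_solve(params, x0, num_steps=3):
    """Plain Newton iterations (exact after one step for a quadratic)."""

    def step(x, _):
        g = jax.grad(objective)(x, params)
        h = jax.hessian(objective)(x, params)
        return x - jnp.linalg.solve(h, g), None

    x, _ = jax.lax.scan(step, x0, None, length=num_steps)
    return x


@jax.custom_vjp
def argmin(params, x0):
    return newton_solve(params, x0)


def argmin_fwd(params, x0):
    x_star = newton_solve(params, x0)
    # Save only the solution and params; the Newton iterates are not stored.
    return x_star, (params, x_star)


def argmin_bwd(residuals, x_bar):
    params, x_star = residuals
    # Optimality condition: F(x*, params) = grad_x objective(x*, params) = 0.
    grad_x = jax.grad(objective)
    hess = jax.hessian(objective)(x_star, params)
    # Adjoint solve: H^T lam = x_bar.
    lam = jnp.linalg.solve(hess.T, x_bar)
    # Implicit function theorem: params_bar = -lam^T dF/dparams.
    _, vjp_params = jax.vjp(lambda p: grad_x(x_star, p), params)
    (params_bar,) = vjp_params(-lam)
    # The minimizer does not depend on the initial guess x0.
    return params_bar, jnp.zeros_like(x_star)


argmin.defvjp(argmin_fwd, argmin_bwd)

# Usage: differentiate a loss of the minimizer w.r.t. the objective's parameters.
a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])
x0 = jnp.zeros(2)
loss = lambda a, b: jnp.sum(argmin((a, b), x0) ** 2)
print(jax.grad(loss, argnums=(0, 1))(a, b))
```

Because this objective only sees the symmetric part of A, the gradient with respect to A comes out symmetrized. The backward pass never touches the Newton iterates, so the gradient graph does not grow with num_steps.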

I am seeing very long compile times in the optimization context, which go away when we apply a stop_gradient before the vector field calculation. I think we should add this stop_gradient now, and later restore the ability to backpropagate through vector field generation using the method mentioned above. That might be fairly involved and would take time. fyi @smartalecH

mfschubert, Jan 10 '24 21:01
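A toy illustration of what the stop_gradient does, with a hypothetical compute_vector_field standing in for the real calculation. With jax.lax.stop_gradient applied to its input, the vector field is treated as a constant during backpropagation, so JAX does not build a backward pass through it (presumably where the compile-time savings come from). The trade-off is that the gradient drops the vector field's contribution.

```python
import jax
import jax.numpy as jnp


def compute_vector_field(permittivity):
    """Hypothetical stand-in for the tangent vector field calculation."""
    return jnp.tanh(permittivity)


def figure_of_merit(permittivity, use_stop_gradient):
    if use_stop_gradient:
        # The field is a constant w.r.t. permittivity for differentiation purposes.
        field = compute_vector_field(jax.lax.stop_gradient(permittivity))
    else:
        field = compute_vector_field(permittivity)
    # Hypothetical downstream computation that also depends on permittivity directly.
    return jnp.sum(field * permittivity)


p = jnp.array([0.5, -1.0, 2.0])
print(jax.grad(figure_of_merit)(p, True))   # only the direct dependence on p
print(jax.grad(figure_of_merit)(p, False))  # also includes d(field)/dp
```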

Yep, this sounds like a good plan to me. How hard do we expect the manual adjoint to be?

smartalecH, Jan 10 '24 22:01

I am looking at it a bit. It might actually be relatively straightforward. Here's a reference that seems nice; it even includes JAX code: https://implicit-layers-tutorial.org/implicit_functions/

mfschubert, Jan 10 '24 22:01

@smartalecH @Luochenghuang I have things working here: all it needed was a bit of regularization.

https://github.com/mfschubert/mewtax

mfschubert, Feb 09 '24 22:02

I think we may want to put this on hold for now: the potential accuracy improvement is small, and there is a speed penalty.

  • I added a test in #94 that checks the finite-difference (FD) gradient against the AD gradient. They agree closely as-is, i.e. even with the stop_gradient in the vector field calculation.
  • I tested using mewtax to solve for the vector fields, but this seems to make the tests much slower (about 2x the time to complete all tests). I suspect there is a significant compile-time penalty.

mfschubert, Feb 28 '24 18:02
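A generic sketch of an FD-versus-AD gradient check of the kind mentioned above. This is not the test from #94; loss is a hypothetical stand-in for a real figure of merit, and central differences are taken in float64 so that truncation and round-off errors stay well below the tolerance.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def loss(x):
    """Hypothetical scalar loss standing in for an fmmax figure of merit."""
    return jnp.sum(jnp.sin(x) * x**2)


def finite_difference_grad(fn, x, step=1e-6):
    """Central-difference gradient, one coordinate at a time."""
    grads = []
    for i in range(x.size):
        delta = jnp.zeros_like(x).at[i].set(step)
        grads.append((fn(x + delta) - fn(x - delta)) / (2 * step))
    return jnp.stack(grads)


x = jnp.linspace(-1.0, 1.0, 5)
ad_grad = jax.grad(loss)(x)
fd_grad = finite_difference_grad(loss, x)
rel_error = jnp.linalg.norm(ad_grad - fd_grad) / jnp.linalg.norm(ad_grad)
print(f"relative error: {rel_error:.2e}")
```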