
Adjoint method to find the gradient of the Laplace approximation/mode


This is part of INLA roadmap #340.

From the Stan paper:

> One of the main bottlenecks is differentiating the estimated mode, $\theta^*$. In theory, it is straightforward to apply automatic differentiation, by brute-force propagating derivatives through $\theta^*$, that is, sequentially differentiating the iterations of a numerical optimizer. But this approach, termed the direct method, is prohibitively expensive. A much faster alternative is to use the implicit function theorem. Given any accurate numerical solver, we can always use the implicit function theorem to get derivatives. One side effect is that the numerical optimizer is treated as a black box. By contrast, Rasmussen and Williams [34] define a bespoke Newton method to compute $\theta^*$, meaning we can store relevant variables from the final Newton step when computing derivatives. In our experience, this leads to important computational savings. But overall this method is much less flexible, working well only when the number of hyperparameters is low-dimensional and requiring the user to pass the tensor of derivatives.

I think the JAX implementation uses the tensor of derivatives, but I'm not 100% sure.
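
For concreteness, here is a minimal JAX sketch (illustrative only, not pymc-experimental code; `inner_objective`, `phi`, and the toy model are made up) of the implicit-function-theorem route described in the quote: if $\theta^*(\phi)$ satisfies $\nabla_\theta f(\theta^*, \phi) = 0$, then $\mathrm{d}\theta^*/\mathrm{d}\phi = -\left[\nabla_\theta^2 f\right]^{-1} \partial_\phi \nabla_\theta f$ evaluated at the mode, so the inner optimizer can be treated as a black box and only second derivatives at the mode are needed.

```python
# Minimal sketch (illustrative, not pymc-experimental code): differentiate the
# mode theta*(phi) of an inner objective via the implicit function theorem,
# treating the inner optimizer as a black box.
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def inner_objective(theta, phi):
    # Toy "negative log joint": phi shifts and rescales the quadratic in theta.
    return 0.5 * jnp.exp(phi) * jnp.sum((theta - phi) ** 2) + 0.5 * jnp.sum(theta**2)


def find_mode(phi, theta0):
    # Any accurate inner solver works; BFGS is the one jax.scipy ships.
    return minimize(lambda th: inner_objective(th, phi), theta0, method="BFGS").x


def mode_jacobian(phi, theta_star):
    # Stationarity: grad_theta f(theta*, phi) = 0, so by the implicit function
    # theorem  d theta*/d phi = -H^{-1} @ (d/dphi grad_theta f),
    # with H the Hessian of f in theta, both evaluated at the mode.
    H = jax.hessian(inner_objective, argnums=0)(theta_star, phi)
    mixed = jax.jacfwd(jax.grad(inner_objective, argnums=0), argnums=1)(theta_star, phi)
    return -jnp.linalg.solve(H, mixed)


phi = jnp.array(0.3)
theta_star = find_mode(phi, jnp.zeros(3))
print(mode_jacobian(phi, theta_star))  # sensitivity of the mode w.r.t. phi
```

This corresponds to the black-box route in the quote; a Rasmussen-and-Williams-style implementation would instead reuse quantities from the final Newton step of a bespoke inner solver (e.g. the factorized Hessian) when computing the derivatives.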

— theorashid, May 15 '24 13:05