
Extracting the Hessian at the minimum from the optimization `Solution`

michael-0brien opened this issue 10 months ago · 4 comments

Hello! For my work I am using optimistix for a Laplace approximation of a posterior. To do this, I am currently 1) using an LM solve to minimize residuals and 2) after the optimization, defining a scalar L2 objective function (log likelihood) and computing the Hessian at the minimum (Fisher information). A very rough schematic of what this looks like:

import equinox as eqx
import optimistix as optx

def residual_loss_fn(y, args):
    # Residuals minimised by the LM solve.
    ...

def scalar_loss_fn(y, args):
    # Scalar L2 objective (log likelihood) built from the residuals.
    ...

@eqx.filter_hessian
def hessian_fn(y, args):
    # Hessian of the scalar objective, computed with autodiff.
    return scalar_loss_fn(y, args)

y0, args = ...
solver = optx.LevenbergMarquardt(...)
sol = optx.least_squares(residual_loss_fn, solver, y0, args=args)

scalar_loss_at_minimum = scalar_loss_fn(sol.value, args)
hessian_at_minimum = hessian_fn(sol.value, args)

I suspect that it would be more robust to leverage the optimization algorithm to build up an estimate of the Hessian around the optimum of the loss surface rather than just evaluating the Hessian at the optimum itself. I am wondering if there is a way to do this directly from the optimization results in sol. So far I have been looking to see if I can do this using the f_info in sol.state but haven't been able to figure it out. Also, if it is not possible to do this using the LM results, I'd be interested to hear whether it would be possible with a different algorithm (although LM seems to work the best by far for my problem).

I'll note that my optimization is over 5-10 parameters, so it is reasonable to materialize a dense hessian.

michael-0brien avatar Feb 28 '25 16:02 michael-0brien

Hi @mjo22,

This is a routine step in my application as well: I use LM to optimise the parameters of an ODE, and I then use the Hessian / Fisher information of the log-likelihood (GLS in my case, skipping over some details), taken at the optimum of the nonlinear solve.

I think this is the best way to do it. You could theoretically access the second-order information from the solver's function information, by accessing the linear operator and getting its matrix representation (lx.AbstractLinearOperator.as_matrix()). However, this matrix is only going to give you an approximation, not the true Hessian. In the case of LM, it is the Jacobian of the residuals, so you would also have to normalise it such that you get $J^T J$, which is equivalent only to the empirical Fisher information (the latter having various drawbacks).

BFGS would give you a square matrix, but this is an approximation in itself, and not equivalent to the full Hessian either (which is actually what gives it some of its favourable properties during optimisation).

What specifically would you like to improve? I'm assuming that you use this for uncertainty quantification(?). If you're interested in fatter tails, sampling is the way to go. However, if your problem is reasonably well behaved, then the overlap with the quadratic assumption you get out of the mean + Fisher is usually quite good.
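(For reference, the "mean + Fisher" route is just the Laplace step: a Gaussian centred at the optimum, with the inverse Hessian as the covariance. A minimal sketch, reusing sol and hessian_at_minimum from your snippet above and assuming y is a flat parameter array so the Hessian is a dense (n, n) matrix:)

import jax.numpy as jnp

# Laplace approximation around the optimum: Gaussian with the optimum as the
# mean and the inverse Hessian of the (negative) log likelihood as covariance.
mean = sol.value
covariance = jnp.linalg.inv(hessian_at_minimum)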

johannahaffner avatar Feb 28 '25 16:02 johannahaffner

Thanks so much for the timely response! This is very helpful.

I use the approximation to analytically compute marginal posteriors via Gaussian integrals in some of my parameters (so in a sense, yes, I use this for uncertainty quantification). I haven't ruled out trying sampling, but at this stage it is best to first try Laplace; my problem is very messy, as I am mapping out a multi-modal likelihood surface with good initial guesses for each peak location. There could be around 1000 peaks, each of which I find using an LM solve.

It is interesting to hear that my current method may be the best option. The reason I suspect it could be beneficial not to rely on a pointwise estimate of the Hessian is that my likelihood surface is very noisy. I have a few questions, whenever you get the time:

  • To be honest I am still getting familiar with the details of LM; my impression is that the jacobian stored in the Solution is an estimate built up over the course of the algorithm, rather than the value simply at the minimum. Is this true? If so, roughly how is the jacobian returned from the LM solve computed / is there a reference that would be enlightening here?
  • I am not totally familiar with estimation theory for the fisher information and the pitfalls of this compared to simply using autodiff at the solution. Are there any references you like that could be helpful for me (e.g. on formulating the empirical fisher information and its pitfalls, successes of JAX autodiff for the laplace approx, etc)?
  • It's a little unclear to me at this stage what will be better, so I may want to try both and run some tests. Would you happen to be able to help out with some pseudocode for formulating the estimator? Would it be as simple as the following?
sol = ...
jac_operator = sol.state.f_info.jac  # lineax operator for the Jacobian of the residuals

@eqx.filter_jit
def compute_hessian(jac_operator):
    jac_matrix = jac_operator.as_matrix()  # shape (n_residuals, n_parameters)?
    return jac_matrix.T @ jac_matrix  # Gauss-Newton / empirical-Fisher style J^T J

hessian_estimator = compute_hessian(jac_operator)

I suppose another option is to compute the true Hessian at every optimization step and return it as auxiliary information from my compute_residuals_fn; then I could formulate an estimator from this information. However, I suspect the slowdown from computing the Hessian might not be worth it.
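Concretely, I imagine something like the following, using optimistix's has_aux option; though as far as I understand, sol.aux would only hand me back the auxiliary value at the final point rather than the whole history:

def residuals_with_hessian(y, args):
    residuals = residual_loss_fn(y, args)
    # Auxiliary output: the Hessian of the scalar loss at this iterate.
    aux = eqx.filter_hessian(scalar_loss_fn)(y, args)
    return residuals, aux

sol = optx.least_squares(
    residuals_with_hessian, solver, y0, args=args, has_aux=True
)
# sol.aux then holds the auxiliary output (at the final point, I believe).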

michael-0brien avatar Feb 28 '25 17:02 michael-0brien

Glad to hear that it helps! In order:

> my likelihood surface is very noisy.

Is the noise homoscedastic or does it have a pattern you can account for (in generalised least squares)?

> • my impression is that the jacobian stored in the Solution is an estimate built up over the course of the algorithm, rather than the value simply at the minimum. Is this true?

No, this is not true. The Jacobian is generated at each iterate y_eval; if the solver terminates, all returned values are those at the last (successful) y_eval. (Computation of the Jacobian happens in the function called here.)
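If you want to convince yourself of this, you can recompute the Jacobian at the returned optimum and compare it to the stored operator. A rough sketch, assuming y is a flat array and the f_info access from your snippet works as written:

import jax
import jax.numpy as jnp

# Jacobian of the residuals, recomputed at the returned optimum.
recomputed_jac = jax.jacfwd(residual_loss_fn)(sol.value, args)

# Should match the operator the solver stored at the last successful step
# (up to floating point noise).
stored_jac = sol.state.f_info.jac.as_matrix()
print(jnp.allclose(recomputed_jac, stored_jac))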

In BFGS, the Hessian is indeed built up over time, using the gradients taken at the accepted iterates. (This happens here.)

> • Would [computing the Hessian approximation] be as simple as the following? ...

Yes, but I doubt that it would be very useful. Which brings me to the next point:

> • I am not totally familiar with estimation theory for the fisher information and the pitfalls of this compared to simply using autodiff at the solution.

The key here is to think about what these matrices encode. The Jacobian of the residuals is a first-order object: it only describes how each residual, assumed to be independent of its neighbours (unless you account for their dependency pattern in your residual function), is affected by a change in the parameters. By normalising this to get $J^T J$, you are assuming that second-order interactions of the parameters can be approximated by products of first derivatives, i.e. that

$$ \partial_{p_1} [\partial_{p_2} [f]] \approx \partial_{p_1} [f] \cdot \partial_{p_2} [f] $$

which is a strong assumption to make, especially since you're hoping that the first derivatives tend to zero at the optimum.
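If you want to see how much this matters in practice, you can compare the two objects directly. A small self-contained sketch on a toy residual function (not your actual model):

import jax
import jax.numpy as jnp

def residuals(y):
    # Toy nonlinear residuals; stands in for an actual residual function.
    return jnp.array([y[0] ** 2 - 1.0, jnp.sin(y[1]) - 0.5, y[0] * y[1]])

def scalar_loss(y):
    r = residuals(y)
    return 0.5 * jnp.sum(r**2)

y = jnp.array([1.1, 0.6])  # some point near an optimum

full_hessian = jax.hessian(scalar_loss)(y)
jacobian = jax.jacfwd(residuals)(y)
gauss_newton = jacobian.T @ jacobian  # the J^T J approximation

print(full_hessian - gauss_newton)  # the residual-curvature term that J^T J neglects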

In BFGS, we do build up the Hessian over time. Its derivation looks a lot like a finite difference approximation to a (second-order) derivative, but instead of taking the limit over very small intervals, we select a finite, not-too-small interval so that we can make good progress in our optimisation. As we converge, these intervals will naturally get smaller, so the approximation improves. It will carry some amount of historical information, which brings me to your motivation: it sounds like you want to compensate for the noisiness of your likelihood surface by taking into account more than one point.

I think there is a pitfall there, which is that these solvers are all deterministic. They do not explore, the way a sampler would. The historical information you would get is simply whatever is accumulated in the approach to that one point, not necessarily what would happen one mini-peak over. And we also cannot be sure that the Hessian approximation that does carry a little bit of history is more conservative in any sense, such as leading to a wider distribution.
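In symbols, the finite-difference flavour is the secant condition: with step $s_k = y_{k+1} - y_k$ and gradient difference $g_k = \nabla f(y_{k+1}) - \nabla f(y_k)$, the updated approximation is required to satisfy $B_{k+1} s_k = g_k$, and the standard BFGS update that achieves this is

$$ B_{k+1} = B_k - \frac{B_k s_k s_k^\top B_k}{s_k^\top B_k s_k} + \frac{g_k g_k^\top}{g_k^\top s_k}. $$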

(I do think that you could warm-start a sampler with your LM optimum, and any Hessian - pretty sure you can supply these in blackjax window adaptation for NUTS (?).)
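A rough sketch of what such a warm start could look like, using the blackjax NUTS kernel directly rather than window adaptation (treat the exact blackjax call as an assumption to check against its docs; the step size is an arbitrary placeholder, scalar_loss_fn, sol and hessian_at_minimum are reused from the snippets above, and y is assumed to be a flat array):

import blackjax
import jax
import jax.numpy as jnp

def log_posterior(y):
    # Placeholder: up to priors and constants, the negative of the scalar loss.
    return -scalar_loss_fn(y, args)

y_opt = sol.value                              # LM optimum as the starting point
inv_mass = jnp.linalg.inv(hessian_at_minimum)  # inverse Hessian as a dense inverse mass matrix

nuts = blackjax.nuts(log_posterior, step_size=1e-2, inverse_mass_matrix=inv_mass)
state = nuts.init(y_opt)

rng_key = jax.random.PRNGKey(0)
for _ in range(1000):
    rng_key, subkey = jax.random.split(rng_key)
    state, info = nuts.step(subkey, state)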

You also might want to take a look at BestSoFarLeastSquares, which did improve the values I got a little bit.
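It simply wraps your existing solver and returns the best iterate seen so far rather than the last accepted one; something like this (the tolerances are placeholders):

solver = optx.BestSoFarLeastSquares(optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8))
sol = optx.least_squares(residual_loss_fn, solver, y0, args=args)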

> • Are there any references you like that could be helpful for me (e.g. on formulating the empirical fisher information and its pitfalls, successes of JAX autodiff for the laplace approx, etc)?

I'm not aware of a succinct modern reference; I learned this from older stats literature. Do share if you find something!

> I suppose another option is to compute the true Hessian at every optimization step and return it as auxiliary information from my compute_residuals_fn; then I could formulate an estimator from this information. However, I suspect the slowdown from computing the Hessian might not be worth it.

The initial approximations will also not be very good.

johannahaffner avatar Mar 03 '25 18:03 johannahaffner

Sorry for the late reply. This is very helpful and makes sense, particularly with regard to how to think about the noisy loss surface.

michael-0brien avatar May 14 '25 14:05 michael-0brien