
Creating an `AbstractMinimiser` with `DampedNewtonDescent`

Open · michael-0brien opened this issue 2 months ago · 3 comments

Hello! I'd like to try out a minimizer that is a hybrid first-order / second-order method like Levenberg-Marquardt (LM). I am optimizing a noisy scalar loss, and I've found that LM works very well compared to an Optax approach like AdaBelief. In general, though, the sum of squares is not the best choice of loss for me.

From reading through the code, my understanding is this should be as simple as the following:

from collections.abc import Callable

import optimistix as optx
from jaxtyping import PyTree, Scalar


class HybridMinimiser(optx.AbstractLBFGS):
    rtol: float
    atol: float
    norm: Callable[[PyTree], Scalar]
    use_inverse: bool
    descent: optx.DampedNewtonDescent
    search: optx.ClassicalTrustRegion
    history_length: int
    verbose: frozenset[str]

    def __init__(
        self,
        rtol: float,
        atol: float,
        norm: Callable[[PyTree], Scalar] = optx.max_norm,
        use_inverse: bool = True,
        history_length: int = 10,
        verbose: frozenset[str] = frozenset(),
    ):
        self.rtol = rtol
        self.atol = atol
        self.norm = norm
        self.use_inverse = use_inverse
        self.descent = optx.DampedNewtonDescent()
        self.search = optx.ClassicalTrustRegion()
        self.history_length = history_length
        self.verbose = verbose

My understanding is that such a minimizer would use an LBFGS-like approach to estimate the Hessian, and then use that for the damped Newton step, rather than the Gauss-Newton approximation used in LM. (By the way, this modularity is amazing and the developers of optimistix should pat themselves on the back!)
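For context, here is roughly how I'd expect to call it (the objective below is a made-up stand-in for my real loss, just to show the optx.minimise plumbing around the class above):

import jax.numpy as jnp
import optimistix as optx

def loss(y, args):
    # Toy noisy scalar loss, standing in for my actual objective.
    del args
    return jnp.sum(jnp.abs(y - 1.0)) + 0.1 * jnp.sum(jnp.sin(50.0 * y) ** 2)

solver = HybridMinimiser(rtol=1e-6, atol=1e-8)
y0 = jnp.zeros(100)
sol = optx.minimise(loss, solver, y0)
print(sol.value)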

I have two questions:

  • I'd love to hear from someone with more expertise whether there are any footguns in what I'm doing here. For example, I don't yet understand which search I should choose for such a method (I'm not very familiar with searches in general); above I used ClassicalTrustRegion, the LM search, but maybe BacktrackingArmijo, the BFGS search, is better.
  • I am optimizing a loss with respect to 100 or so parameters, but many parameters are independent and my Hessian will be block diagonal. Is there a way to indicate this with the tags argument, or would I need to write custom init_hessian and update_hessian methods too?

michael-0brien avatar Oct 31 '25 14:10 michael-0brien

Hi!

Great to hear that you're experimenting with custom solvers :)

First off, as a clarifying question: If you have a scalar loss, why is squaring it undesirable? Do you think this is due to the noisiness, and would you like to increase robustness (e.g. as one would by minimising absolute distances rather than squared ones)? You can return any scalar loss when using the minimisers, but your loss function should be bounded below, and ideally (locally) convex. In practice that frequently means either squaring the residuals or taking their absolute values.
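For example (purely illustrative, with residuals standing in for whatever your model produces):

import jax.numpy as jnp

def squared_loss(residuals):
    # The usual sum-of-squares choice: smooth, bounded below, locally convex.
    return jnp.sum(residuals**2)

def absolute_loss(residuals):
    # More robust to large/noisy residuals, still bounded below.
    return jnp.sum(jnp.abs(residuals))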

Now, for the code - yes, that is pretty close! There are some things (+ your questions) to consider here:

Hessian approximation + regularisation

  • Tikhonov-type regularisers (small multiples of the identity) are typically added to the Hessian itself, rather than to its inverse. This has good numerical properties, preventing ill-conditioned matrices from becoming (near-)singular by shifting their eigenvalues just a little bit. If you're just adding a regulariser to the inverse and then applying this operator to the gradient, then you'd simply shift the solution instead of tweaking the eigenvalues.
  • For the reason above, damped_newton_step does not currently support direct regularisation of the inverse Hessian.
  • quasi-Newton Hessian (or inverse Hessian) approximations are self-regularised already: updates are computed such that they remain positive-definite.

If you'd like to use the LBFGS approximation to the Hessian instead, then damped_newton_step should work out of the box (just set inverse=False).
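To make the first point above concrete, here is a tiny linear-algebra sketch (the matrices and the damping value are made up; this is not the optimistix internals):

import jax.numpy as jnp

# A toy ill-conditioned Hessian and a gradient.
H = jnp.array([[1.0, 0.0], [0.0, 1e-8]])
grad = jnp.array([1.0, 1.0])
lam = 1e-3  # made-up Tikhonov damping

# Regularising the Hessian shifts its eigenvalues away from zero before solving.
step_regularised = -jnp.linalg.solve(H + lam * jnp.eye(2), grad)

# "Regularising" the inverse instead only adds lam * grad to the step; it does
# nothing about the conditioning of H, it just shifts the result.
step_shifted = -(jnp.linalg.inv(H) + lam * jnp.eye(2)) @ grad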

> I am optimizing a loss with respect to 100 or so parameters, but many parameters are independent and my Hessian will be block diagonal. Is there a way to indicate this with the tags argument, or would I need to write custom init_hessian and update_hessian methods too?

Sparsity structure is not preserved by quasi-Newton approximations, since these are generated from inner and outer products of gradients. You could write your own update method that works directly on the true Hessian and regularises that, but Lineax does not yet have support for block diagonal operators. (This could only be leveraged properly once we have decent sparse linear solvers in JAX, though.)
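To illustrate why the sparsity is lost: even when the true Hessian is block diagonal, the rank-one terms in BFGS-type updates are outer products of (generally dense) gradient differences and steps, so the approximation fills in immediately. A toy sketch, not tied to any particular solver:

import jax.numpy as jnp

# Two independent two-parameter blocks; the true Hessian would be block diagonal.
g_prev = jnp.array([0.3, -0.1, 0.7, 0.2])  # toy gradient at the previous iterate
g_new = jnp.array([0.1, 0.0, 0.5, 0.4])    # toy gradient at the new iterate
s = jnp.array([0.05, 0.02, -0.01, 0.06])   # toy parameter step
y = g_new - g_prev

# One of the rank-one update terms: a dense 4x4 matrix with no block structure.
update_term = jnp.outer(y, y) / jnp.dot(y, s)
print(jnp.count_nonzero(update_term))  # 16: every entry is filled in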

Searches

For a damped Newton direction, a trust region search is a good choice - the two are also paired in LevenbergMarquardt. I would not expect the backtracking line search to be very helpful here: it is quite conservative, halving the step size until a purely gradient-based criterion is met, and it does not account for the fact that with a damped Newton descent the step direction itself changes with the step length. You could try it, of course! But since you have had good experiences with trust-region-based searches, those seem like the way to go.
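For reference, swapping between the two in a solver like yours is just a one-line change (the class names below are the ones already used in your snippet and in optimistix):

import optimistix as optx

search = optx.ClassicalTrustRegion()  # trust region, as paired with DampedNewtonDescent in LevenbergMarquardt
# search = optx.BacktrackingArmijo()  # conservative backtracking line search, as used in BFGS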

johannahaffner avatar Oct 31 '25 17:10 johannahaffner

> First off, as a clarifying question: If you have a scalar loss, why is squaring it undesirable? Do you think this is due to the noisiness, and would you like to increase robustness (e.g. as one would by minimising absolute distances rather than squared ones)? You can return any scalar loss when using the minimisers, but your loss function should be bounded below, and ideally (locally) convex. In practice that frequently means either squaring the residuals or taking their absolute values.

In my case, my loss function is a (negative) log likelihood. A good choice of likelihood depends on the distribution of the noise in the data; if it is non-Gaussian (e.g. it has skewness), a Gaussian likelihood (L2 loss) will overfit the noise. In my scientific domain, the underlying noise distribution of our data is not well understood, and empirically we often find that choices besides L2 work well.
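As a toy example of what I mean (not my actual model), a Student-t negative log likelihood penalises large residuals far less than the Gaussian one, which helps when the noise has heavy tails:

import jax.numpy as jnp

def gaussian_nll(residuals, sigma=1.0):
    # Up to additive constants this is the usual L2 loss.
    return 0.5 * jnp.sum((residuals / sigma) ** 2)

def student_t_nll(residuals, nu=3.0, sigma=1.0):
    # Heavier tails: large residuals contribute logarithmically, not quadratically.
    return 0.5 * (nu + 1.0) * jnp.sum(jnp.log1p((residuals / sigma) ** 2 / nu))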

Even if L2 is a good choice, it is also common for us to optimize likelihoods (or posteriors) marginalized over nuisance parameters. Of course this is a case for a sampler, but we can also arrive at a scalar loss by analytically integrating over unknowns. This yields a wide range of possible loss functions.

For this reason, it is very important to be able to experiment with different scalar losses! I hope this clarifies.

> Now, for the code - yes, that is pretty close! There are some things (+ your questions) to consider here:
>
> Hessian approximation + regularisation
>
> • Tikhonov-type regularisers (small multiples of the identity) are typically added to the Hessian itself, rather than to its inverse. This has good numerical properties, preventing ill-conditioned matrices from becoming (near-)singular by shifting their eigenvalues just a little bit. If you're just adding a regulariser to the inverse and then applying this operator to the gradient, then you'd simply shift the solution instead of tweaking the eigenvalues.
> • For the reason above, damped_newton_step does not currently support direct regularisation of the inverse Hessian.
> • quasi-Newton Hessian (or inverse Hessian) approximations are self-regularised already: updates are computed such that they remain positive-definite.
>
> If you'd like to use the LBFGS approximation to the Hessian instead, then damped_newton_step should work out of the box (just set inverse=False).

Very cool! This is very helpful to understand.

> Sparsity structure is not preserved by quasi-Newton approximations, since these are generated from inner and outer products of gradients. You could write your own update method that works directly on the true Hessian and regularises that, but Lineax does not yet have support for block diagonal operators. (This could only be leveraged properly once we have decent sparse linear solvers in JAX, though.)

Got it. This makes sense.

> Searches
>
> For a damped Newton direction, a trust region search is a good choice - the two are also paired in LevenbergMarquardt. I would not expect the backtracking line search to be very helpful here: it is quite conservative, halving the step size until a purely gradient-based criterion is met, and it does not account for the fact that with a damped Newton descent the step direction itself changes with the step length. You could try it, of course! But since you have had good experiences with trust-region-based searches, those seem like the way to go.

Excellent! Thanks for the advice. I’m excited to try this out.

michael-0brien avatar Oct 31 '25 17:10 michael-0brien

> For this reason, it is very important to be able to experiment with different scalar losses! I hope this clarifies.

Very interesting, thank you for the explanation!

johannahaffner avatar Oct 31 '25 19:10 johannahaffner