
Scale-invariant Levenberg-Marquardt

Open michael-0brien opened this issue 2 months ago • 13 comments

Hi optimistix developers! I was hoping to start a discussion about improving optimistix support for working in "unnormalized" parameter spaces. Let me explain what I mean by this.

I come from a scientific domain where the task is to infer the parameters of a physical model from messy experimental data (think gradient-based optimization of a loss function with respect to the parameters of a forward model for the data). In these tasks, it is very important to define your optimizer starting with some heuristics about your parameter space.

Arguably the most important / most common thing needed is to define what a "unit" step size is. Physical parameters vary on vastly different scales; one reason is that parameters carry units, another is simply that the loss-function geometry is non-trivial. This is related to the success of the MCMC sampler implemented in emcee in the physical sciences, which is insensitive to these factors thanks to its affine-invariance property.

I was generally wondering about your thoughts on this, and also hoping you may accept a PR adding support for defining a "unit" step size in the appropriate cases. In my fork of optimistix, I've implemented three things that are useful for me in this regard:

  • ClassicalTrustRegion should allow users to specify an initial_step_size, rather than always defaulting to 1.0. In the coming weeks I can try to put together an MRE showing why this is important, but I believe convergence in unnormalized parameter spaces should be much faster if this can be tuned.
  • DampedNewtonDescent should define the Levenberg-Marquardt parameter as unit_step_size / step_size, rather than 1 / step_size. 1 may not be meaningful in the case I describe.
  • Solvers should allow users to optionally pass their respective AbstractSearch instances. I want to have the option to change ClassicalTrustRegion parameters. This also seems appropriate for BacktrackingArmijo in the BFGS solvers.
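To make the second bullet concrete, here is a minimal numpy sketch of damping with a Levenberg-Marquardt parameter of 1 / step_size (`damped_newton_step` is a hypothetical helper written for illustration, not optimistix API): a large step size recovers the Newton step, while a small one degenerates to a gradient step scaled by step_size, so "step_size = 1" is only meaningful relative to the scale of the parameters.

```python
import numpy as np

def damped_newton_step(hess, grad, step_size):
    # Levenberg-Marquardt-style damping with lam = 1 / step_size:
    # solve (H + lam * I) step = -grad.
    lam = 1.0 / step_size
    return np.linalg.solve(hess + lam * np.eye(len(grad)), -grad)

hess = np.diag([1.0, 100.0])
grad = np.array([1.0, 1.0])

# Large step_size: the damping vanishes and we recover the Newton step.
newton = np.linalg.solve(hess, -grad)
# Small step_size: the step tends to -step_size * grad, so what counts as
# a "unit" step depends entirely on how the parameters are scaled.
```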

These are the changes that would be helpful for my work, but there may be other changes that would support the general theme I am describing.

Note: I understand that there are other ways of handling the cases I describe, such as applying parameter rescalings directly to my model. I believe what I'm proposing is closely related to the ability to tune "learning_rate" in optax-like optimizers, an API many people are accustomed to (and, for me, syntactically much easier than applying rescalings). It would be good to give users some flexibility in this regard.

michael-0brien avatar Nov 07 '25 20:11 michael-0brien

I was reading up on some literature on this problem and have been looking through scipy's implementation of LM. It is based on the work of J. J. Moré, "The Levenberg-Marquardt Algorithm: Implementation and Theory": https://www.osti.gov/servlets/purl/7256021-WWC9hw/

Here, a procedure is described that makes Levenberg-Marquardt scale invariant (see section 6). This is implemented in scipy via MINPACK so of course has been a cornerstone of scientific optimization for decades.

The gist of this procedure (if you are not familiar already!) is to update not only the LM parameter (or the trust region) but also a matrix that accounts for the scaling of the problem. For the $k$th iteration, the regularisation looks like $H_k + \lambda_k D_k^T D_k$, where $D_k$ is typically chosen to be a diagonal matrix.
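A single step of this scheme might look like the following numpy sketch, assuming the squared operator $H_k = J^T J$ and Moré's choice of $D_k$ from the Jacobian column norms (`scaled_lm_step` is illustrative, not optimistix API). The point of the $D_k^T D_k$ regulariser is that rescaling a parameter, and hence the corresponding Jacobian column, leaves the step in the original units unchanged:

```python
import numpy as np

def scaled_lm_step(jac, residuals, lam):
    # Regularise with lam * D^T D instead of lam * I, where D is the
    # diagonal of Jacobian column norms (guarding against zero columns).
    d = np.linalg.norm(jac, axis=0)
    d = np.where(d == 0.0, 1.0, d)
    return np.linalg.solve(jac.T @ jac + lam * np.diag(d**2), -jac.T @ residuals)
```

Rescaling the $j$th parameter by $s_j$ multiplies the $j$th Jacobian column by $s_j$ and divides the computed step component by $s_j$, so the predicted update to the model is invariant.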

Rather than the minimal change I suggested above of allowing specification of a “unit” step size, I think it would be great to implement this given the algorithm’s track record. I am curious to hear thoughts on this, if there’s any interest in helping me out, and what the developers would like in terms of API (if this would be accepted). I could also use some help from someone with more expertise than myself in parsing through the literature!

michael-0brien avatar Nov 08 '25 04:11 michael-0brien

I was hoping to start a discussion about improving optimistix support for working in "unnormalized" parameter spaces.

Very happy to have one! There have already been some previous musings about the topic: https://github.com/patrick-kidger/optimistix/issues/92.

Generally, in Newton-type solvers (where we solve some linear equation of the form Hess * step = grad, or Jac * step = residuals), regularisation typically makes the operator more gradient-like by adding a multiple of identity. This is frequently helpful in preventing convergence to stationary points of the loss landscape, at the cost of making it slightly more sensitive to parameter scalings, at least wherever the regulariser is a scalar multiple of the identity.

Now, more concretely:

  • ClassicalTrustRegion should allow users to specify an initial_step_size, rather than always defaulting to 1.0. In the coming weeks I can try to put together an MRE showing why this is important, but I believe convergence in unnormalized parameter spaces should be much faster if this can be tuned.

Do you propose making a change in this line? https://github.com/patrick-kidger/optimistix/blob/b414db226f6ade1854808c40e4a6c4fc4aba7f71/optimistix/_solver/trust_region.py#L75

The way it is written, this will only impact the very first step taken, not the first-step-after-an-accepted-step. (This was also a topic of discussion in the aforementioned thread.) If you are instead interested in tuning the first checked step size after the previous iterate has been accepted, then the change would be slightly more involved but this can definitely be done. If the latter, then I think this would be sufficiently different from the existing ClassicalTrustRegion that it should probably be its own search, tweaking the step method. I'd be happy to take a PR on this.
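As a toy illustration of why the initial size matters (pure Python, not the optimistix search API): on a 1-D problem whose natural scale is 1e-3, a loop that starts trial steps at 1.0 and halves on each rejection burns through many rejected trials, while starting at the problem's own scale accepts immediately.

```python
def count_rejections(x0, initial_step_size, shrink=0.5, max_trials=50):
    # Toy 1-D trust-region-style loop on f(x) = x**2: propose a step of
    # length `step_size` toward the minimum, accept once f decreases, and
    # halve the step size after every rejection.
    f = lambda x: x * x
    step_size = initial_step_size
    rejections = 0
    for _ in range(max_trials):
        trial = x0 - step_size if x0 > 0 else x0 + step_size
        if f(trial) < f(x0):
            return rejections
        step_size *= shrink
        rejections += 1
    return rejections

# Problem lives at a scale of 1e-3: a default initial size of 1.0 wastes
# roughly nine rejected trials; a tuned initial size is accepted at once.
```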

  • DampedNewtonDescent should define the Levenberg-Marquardt parameter as unit_step_size / step_size, rather than 1 / step_size. 1 may not be meaningful in the case I describe.

In this case, it might be better to write a different descent that differs from the existing DampedNewtonDescent in this respect, and perhaps comes with its own function to define what a unit step size is. Can this step size be identified in descent.init, i.e. is this constant throughout the solve?

Since you bring up the comparison to the Scipy implementation, there is one other difference: Scipy normalises the equations, i.e. they use Jac^T Jac to get a square operator. By contrast, we solve a linear least squares problem Jac * step = residuals that is usually overdetermined, since the residual function will typically have many more elements than y does. (In other words, the Jacobian is tall and narrow.) Choosing one of these two approaches will probably impact how the regulariser acts on the system.
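For a well-conditioned tall Jacobian the two formulations agree, which a quick numpy check illustrates; the practical difference is that forming Jac^T Jac squares the condition number, whereas the least-squares route (QR/SVD under the hood) avoids that.

```python
import numpy as np

rng = np.random.default_rng(0)
jac = rng.standard_normal((20, 3))   # tall: more residuals than parameters
res = rng.standard_normal(20)

# Least-squares route: solve Jac @ step = res directly.
step_lstsq = np.linalg.lstsq(jac, res, rcond=None)[0]
# Normal-equations route, as in MINPACK/Scipy: square the operator first.
step_normal = np.linalg.solve(jac.T @ jac, jac.T @ res)
```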

Happy to take a PR on this, too, and I can conditionally support this - by which I mean that I will probably only have time to read papers and think more deeply about this in mid-December, if this is required. That said, it does not sound too complicated to me.

  • Solvers should allow users to optionally pass their respective AbstractSearch instances. I want to have the option to change ClassicalTrustRegion parameters. This also seems appropriate for BacktrackingArmijo in the BFGS solvers.

So this can already be done in two ways - the first is simply

import optimistix as optx
import equinox as eqx

solver = optx.LevenbergMarquardt()
search = optx.LearningRate(learning_rate=0.4)

solver = eqx.tree_at(lambda s: s.search, solver, search)

which will not throw an error and already allows you to (hackily) switch searches. The other way that we support this is through the definition of custom solvers, which is only slightly more verbose. (I'm not inclined to support/upstream a union type Search | None as in your fork, but I'm guessing that this is done for experimentation purposes.)

johannahaffner avatar Nov 08 '25 09:11 johannahaffner

Thanks for the thoughtful reply on this! Long reply incoming… so please no rush to get back to me!

Do you propose making a change in this line?

optimistix/_solver/trust_region.py, line 75 in b414db2:

return _TrustRegionState(step_size=jnp.array(1.0))

The way it is written, this will only impact the very first step taken, not the first-step-after-an-accepted-step. (This was also a topic of discussion in the aforementioned thread.)

Yes! This should prevent a number of rejected iterations when we have some prior knowledge about how the problem should behave. I have a problem where I know the early iterations should be more gradient-like. My iterations are very expensive, and typically about 10-15 are thrown out without changing this line.

Happy to take a PR on this, too, and I can conditionally support this - by which I mean that I will probably only have time to read papers and think more deeply about this in mid-December, if this is required. That said, it does not sound too complicated to me.

I came across this work, which describes 1) the issue of scaling and bad geometry and 2) how the choice of the $D_k$ scaling matrix addresses it. Moreover, on reading through the literature, implementing scale-invariance through this $D_k$ seems to be canon in LM implementations. I think it's a big part of their success; we should get it implemented!

FWIW I talk with other JAX/Equinox/physical-sciences devs and the issue of scale-invariance is one we often run into. Typically we have to hack together a solution, but this is not usually the case in standard scientific computing tools. MINPACK LM has powered many applications because its scale-invariance makes it "just work" out of the box. It would be great to have this in JAX. See the argument x_scale in scipy.optimize.least_squares: it defaults to 'jac' for LM, which essentially uses the Jacobian to iteratively estimate $D_k$.

How about I can get an implementation working in my fork as I’ll need this for my work, and then we can evaluate when you have some more time?

So this can already be done in two ways - the first is simply

import optimistix as optx
import equinox as eqx

solver = optx.LevenbergMarquardt()
search = optx.LearningRate(learning_rate=0.4)

solver = eqx.tree_at(lambda s: s.search, solver, search)

which will not throw an error and already allows you to (hackily) switch searches. The other way that we support this is through the definition of custom solvers, which is only slightly more verbose. (I'm not inclined to support/upstream a union type Search | None as in your fork, but I'm guessing that this is done for experimentation purposes.)

To first make sure we are on the same page, let me clarify: the Search | None in __init__ is just to allow the passing of a default search instance, so users can instantiate one at a different set of parameters than the default. If None is passed, the default search instance is used.

I have to be honest, not including this is a strange design choice to me. It may be biased toward object-oriented thinking, rather than the functional thinking that JAX/Equinox lends itself to. Let me explain what I mean by this.

If we view a ClassicalTrustRegion instance with object-oriented thinking, it would be strange to pass it directly to the init: if I mutate the reference stored globally after init, this will change the reference inside the object. This is undesirable for obvious reasons, and of course does not happen in JAX.

For JAX/Equinox code, at init-time I tend to think about Modules as any pytree of parameters. Passing the ClassicalTrustRegion to init means that I am passing a set of parameters, where the None case means default parameters. If we were not in pytree + functional world, I view this as analogous to passing a dictionary of default settings.

Moreover, from reading through the literature (see this work again), it seems that trust-region parameters are highly problem dependent. For example, they state that depending on how large a problem is, different choices of parameters are important. I typically consider tree_at calls to be for edge cases and hacking pytrees, but if that is so, tuning trust-region parameters is not an edge case and should not count as advanced API for LM. By contrast, changing the descent's AutoLinearSolver to something else does seem like an edge case and would be appropriate for tree_at.

However, if this is not how you think about tree_at I get it; there can be advantages to this as well! In this case, I would just add to the LM docstring that changing the search instance to a ClassicalTrustRegion with different parameters is desirable. Perhaps this would be appropriate for BFGS too. This would help users less inclined to reading the source code.

michael-0brien avatar Nov 08 '25 14:11 michael-0brien

You're very welcome, and thank you for the thoughtful contribution as well!

Trust region parameters

Do you propose making a change in this line?

Yes! This should prevent a number of rejected iterations where we have some prior knowledge about how the problem should behave. I have a problem where I know in general my problem should be more gradient-like in initial iterations. My iterations are very expensive, and there are typically about 10-15 that are thrown out without changing this line.

Then it makes sense to expose this. This should become an attribute of ClassicalTrustRegion, and an argument to the __init__ methods of the solvers that use it (Levenberg-Marquardt and its indirect variant, Dogleg), which is then passed to the search much like e.g. the linear_solver argument is passed to the descent.

This brings me to the second part of your comment. I want to make a distinction between a particular search and the values of its parameters here. To me, a solver is a composition of a search and a descent. For this reason, I would not want to have an optx.BFGS that accepts any type of search - if the type of search is changed, then we have a different (and probably hybrid) solver, and this should be made obvious. Now, the values of the (hyper)parameters used by a particular search very much do not make the solver a different thing. And I agree that the fact that we do not expose them explicitly makes it unduly difficult to alter these. This could be changed by adding a search_parameters argument to the __init__ methods, like this:

from jaxtyping import ScalarLike

class LevenbergMarquardt(...):
    # attributes...

    def __init__(
        # other attributes / arguments
        search_parameters: dict[str, ScalarLike],
        verbose: frozenset[str],
    ):
        # other attributes / arguments
        self.search = optx.ClassicalTrustRegion(**search_parameters)
        self.verbose = verbose

This will both give access to all search attributes (four for the classical trust region, plus the new init step size one you're proposing), and throw an error for misnamed attributes, or ones that don't actually exist (unexpected keyword). And I think this is what you are getting at when you write

If we were not in pytree + functional world, I view this as analogous to passing a dictionary of default settings.

In which case we're in complete agreement.
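The "fail loudly on bad keys" property falls out of **-unpacking for free, which a small stand-in shows (the hyperparameter names here mirror trust-region cutoffs/constants but are illustrative assumptions, not the actual ClassicalTrustRegion signature):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToySearch:
    # Stand-in for a search with keyword hyperparameters; the names are
    # assumed for illustration, not taken from optimistix.
    high_cutoff: float = 0.99
    low_cutoff: float = 0.01
    high_constant: float = 3.5
    low_constant: float = 0.25

def make_solver(search_parameters=None):
    # **-unpacking means a misnamed or non-existent key fails loudly
    # with a TypeError at construction time.
    return ToySearch(**(search_parameters or {}))
```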

(For hybridising solvers and switching searches/descents/everything: I should probably stop advertising my tree_at experimentation hack, this isn't recommended practice unless you know what you're doing and are too lazy to write <10 lines of custom solver definition. I'm not using it in any code that is meant to survive the day, and it is not meant as an alternative to long-term support for things like exposing search parameters in our APIs.)

Scaled Levenberg-Marquardt

I came across this work, which describes 1) the issue of scaling and bad geometry and 2) how the choice of the $D_k$ scaling matrix addresses it. Moreover, on reading through the literature, implementing scale-invariance through this $D_k$ seems to be canon in LM implementations. I think it's a big part of their success; we should get it implemented!

Mark Transtrum wrote a number of papers on this! I haven't read this one, but in general I have a good impression of his work on the topic.

How about I can get an implementation working in my fork as I’ll need this for my work, and then we can evaluate when you have some more time?

That would be perfect!

johannahaffner avatar Nov 09 '25 08:11 johannahaffner

Then it makes sense to expose this. This should become an attribute of ClassicalTrustRegion, and an argument to the __init__ methods of the solvers that use it (Levenberg-Marquardt and its indirect variant, Dogleg), which is then passed to the search much like e.g. the linear_solver argument is passed to the descent.

Sounds good!

This brings me to the second part of your comment. I want to make a distinction between a particular search and the values of its parameters here. To me, a solver is a composition of a search and a descent. For this reason, I would not want to have an optx.BFGS that accepts any type of search - if the type of search is changed, then we have a different (and probably hybrid) solver, and this should be made obvious. Now, the values of the (hyper)parameters used by a particular search very much do not make the solver a different thing. And I agree that the fact that we do not expose them explicitly makes it unduly difficult to alter these. This could be changed by adding a search_parameters argument to the __init__ methods, like this:

from jaxtyping import ScalarLike

class LevenbergMarquardt(...):
    # attributes...

    def __init__(
        # other attributes / arguments
        search_parameters: dict[str, ScalarLike],
        verbose: frozenset[str],
    ):
        # other attributes / arguments
        self.search = optx.ClassicalTrustRegion(**search_parameters)
        self.verbose = verbose

Let me make sure I'm being clear. I'm not suggesting that any search instance can be passed to __init__, but rather something like the following:

class LevenbergMarquardt(...):
    # attributes...

    def __init__(
        # other attributes / arguments
        verbose: frozenset[str],
        *,
        search: ClassicalTrustRegion | None = None,
    ):
        # other attributes / arguments
        self.search = search or ClassicalTrustRegion()  # shorthand for `ClassicalTrustRegion() if search is None else search`
        self.verbose = verbose

This to me is a bit cleaner, easier for users to interpret, and easier on developers as it requires less in terms of docstrings and error checks. I also think one of the elegant parts of Equinox is its ability to avoid intermediate pytrees in cases like this! However, it's true that if the user is duck typing they can violate this in practice. For me, a type-safe approach is enough!

Regardless, I think the dictionary approach is also fine. I will defer to your judgement on this.

How about I can get an implementation working in my fork as I’ll need this for my work, and then we can evaluate when you have some more time?

That would be perfect!

Great! I am not sure yet what the API should look like; I think there are a few options here. It would be appropriate to add some kind of option to the existing LevenbergMarquardt and IndirectLevenbergMarquardt classes that is passed on to the DampedNewtonDescent. What I don't know is whether this should be a string option with multiple "modes", e.g. scaling_mode = 'minpack' or scaling_mode = 'none', or whether it should allow more modularity by letting users pass an update function for the diagonal scaling operator. From reading Transtrum et al., I think the latter is probably better. In that case we would have to provide several such update functions in the optimistix namespace.
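As a sketch of what such an update function could look like, following Moré's section-6 rule of keeping a running maximum of the Jacobian column norms (illustrative numpy; the name `minpack_scale_update` is made up here, and the signature is an assumption about the eventual API):

```python
import numpy as np

def minpack_scale_update(jac, d_prev=None):
    # More (1978), section 6: take the current Jacobian column norms and
    # keep a running maximum across iterations, so the diagonal scaling
    # D_k never shrinks; zero columns are guarded with a scale of 1.
    d = np.linalg.norm(jac, axis=0)
    if d_prev is not None:
        d = np.maximum(d, d_prev)
    return np.where(d == 0.0, 1.0, d)
```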

For getting started, to be honest I don't yet understand how to implement the diagonal scaling operator in the isinstance(f_info, FunctionInfo.ResidualJac) case (as in LM), only in the isinstance(f_info, FunctionInfo.EvalGradHessian) case that I am using for the hybrid solver I asked you about in #181. For now I'll just get the latter case working and then we can discuss!

michael-0brien avatar Nov 09 '25 15:11 michael-0brien

Regardless, I think the dictionary approach is also fine. I will defer to your judgement on this.

Then let's go with the dictionary of search options, I think this is a little more convenient.

Great! I am not sure yet what to do in terms of API, I think there would be a few options here. I think it would be appropriate to add some kind of option in the existing LevenbergMarquardt and IndirectLevenbergMarquardt classes that are passed on to the DampedNewtonDescent.

I think here I would prefer a separate descent, rather than a modification of the existing DampedNewtonDescent. This new descent will probably have to normalise the equations, i.e. work with Jac^T Jac as a square operator, and hence be conceptually quite different. I think in this case it makes for cleaner code to write something separate, perhaps re-using components, rather than adding options. We can then either make it the new default descent in LevenbergMarquardt, or add a new solver to the public API, e.g. a ScaledLevenbergMarquardt or some other descriptive name.

For getting started, to be honest I don't yet understand how to implement the diagonal scaling operator in the isinstance(f_info, FunctionInfo.ResidualJac) as in LM, only the isinstance(f_info, FunctionInfo.EvalGradHessian) case as I am using for my hybrid solver I asked you about in #181. For now I'll just get the latter case working and then we can discuss!

Perfect! It will probably be easiest to just work on the squared operator (Jac^T Jac, as above). I'm not sure how a diagonal scaling operator would translate to the overdetermined case, and if it could still be expected to work as reliably.

johannahaffner avatar Nov 09 '25 16:11 johannahaffner

Got it! Thanks. This is all helpful for me in getting started. I’ll keep you updated.

michael-0brien avatar Nov 09 '25 17:11 michael-0brien

Hi, I have a draft working of scale-invariant LM in my fork: https://github.com/michael-0brien/optimistix/blob/main/optimistix/_solver/levenberg_marquardt.py. To get the PR to the finish line, this will need to be refined, tested, and an API will need to be implemented. For now I've just implemented the AbstractDescent.

I won't have the time to work further on these finer details for a few months, but it would be great to have some help with these steps whenever there is time or mutual interest. Anyway, before the final merge we should make sure we understand the literature.

michael-0brien avatar Nov 10 '25 14:11 michael-0brien

Alright! Then let's pick this up when you have time.

johannahaffner avatar Nov 11 '25 08:11 johannahaffner

Sure! In carving out some time, it would be helpful for me to have a roadmap towards getting it merged + understand further what API is desired.

Some concrete questions:

  • What do we want in terms of AbstractMinimizer implementation?
  • What do we want for the default implementation of the scaling operator?
  • Is there anything special we should do for testing the method, or is it enough to add to the lists in optimistix/tests/helpers.py?

Once I have this (if anything else comes up for you), I should be able to iterate quickly on what I have in my fork.

michael-0brien avatar Nov 11 '25 14:11 michael-0brien

That makes sense, happy to provide it! Can you open a draft PR for me to comment on?

johannahaffner avatar Nov 11 '25 14:11 johannahaffner

I've not followed every aspect of the discussion here too closely, but I've certainly wanted to fix up the scale-invariance for a while. (And not dived into it.)

Having scale-invariance by default sounds desirable to me. For a test, I imagine that we could check the linearity of the step method? Either numerically, or structurally via jax.linear_transpose not raising an error?
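The numerical version of that check could look like the following (a numpy stand-in for a descent's step computation, not the actual optimistix step method): for a fixed Jacobian and damping, the computed step should be a linear function of the residuals.

```python
import numpy as np

def damped_step(jac, residuals, lam=0.1):
    # Stand-in for a step computation on the squared operator:
    # (J^T J + lam * D^2) step = -J^T r, with diagonal scaling D taken
    # from the Jacobian column norms.
    d = np.linalg.norm(jac, axis=0)
    return np.linalg.solve(jac.T @ jac + lam * np.diag(d**2), -jac.T @ residuals)
```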

patrick-kidger avatar Nov 11 '25 17:11 patrick-kidger

This sounds interesting; can you elaborate on this @patrick-kidger? Are you suggesting we should expect linearity in the scale-invariant case but not the current implementation? Let's continue in #189

michael-0brien avatar Nov 11 '25 19:11 michael-0brien