
Question about `BestSoFar` wrapper and handling of `max_steps`

Open marcelosena opened this issue 6 months ago • 9 comments

Hi,

First, I'd like to say thank you for this fantastic package! The active development and quality of the documentation are truly appreciated and inspiring. I am a recent user of JAX, and this type of work makes you appreciate even more how incredible it is.

I have two questions:

  1. Understanding the BestSoFar wrapper: As I understand it, this wrapper returns the parameters that achieved the best objective value during the entire search, rather than the parameters from the final step. For example, if a search evaluated three points, x_1, x_2, and x_3, with objective values f(x_2) < f(x_3) < f(x_1), the BestSoFar wrapper would return x_2. Is this understanding correct?

  2. Behavior upon reaching max_steps: I've noticed that when an optimization reaches max_steps, an error is raised. This seems to differ from a library like SciPy, which would simply stop and return the last state. Is this correct? For my use case, it would be very useful to get the best result found, even if the optimization terminated by reaching max_steps. Is there a recommended way to achieve this and avoid the error?

Thank you again for your time and effort on this project.

marcelosena avatar Jul 14 '25 21:07 marcelosena

Hi,

thanks for the appreciation! It means a lot.

And yes, you can do that - you can disable this error by passing `throw=False` to your favourite top-level API (`optx.{minimise, least_squares, ...}`). In this case it is recommended to examine the solution - maybe the solver reached the maximum number of steps, but a solve may also fail for other reasons, such as failed linear solves.
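For instance, a minimal sketch (the quadratic objective and the choice of BFGS here are placeholders, not part of the original question):

```python
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    # Toy quadratic objective, purely for illustration.
    return jnp.sum((y - 1.0) ** 2)


solver = optx.BFGS(rtol=1e-8, atol=1e-8)

# With throw=False, hitting max_steps (or any other failure) no longer raises;
# the failure mode is instead recorded in solution.result.
solution = optx.minimise(fn, solver, jnp.zeros(3), max_steps=100, throw=False)
```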

You'll want to check

`status = solution.result == optx.RESULTS.nonlinear_max_steps_reached`

and perhaps also check some other metrics to assess the quality of the solution returned. For example, you could check how different the values for y and y_eval, or the function values f and f_eval, are; these are accessible through the solver state. If these are close, then you may be converging, even if the numerical tolerances are not yet met. (Comparing these is how Optimistix assesses convergence, too.)
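Continuing the sketch above, the iterate comparison might look like this (the `y_eval` field name is taken from the description above and depends on the solver's state structure, so inspect `solution.state` to confirm what your solver actually exposes):

```python
# Compare the last accepted iterate with the last proposed iterate. If they
# are already close, the solve was probably still converging when it hit
# max_steps, even though the tolerances were not yet met.
y_diff = jnp.linalg.norm(solution.value - solution.state.y_eval)
```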

johannahaffner avatar Jul 14 '25 22:07 johannahaffner

> First, I'd like to say thank you for this fantastic package! The active development and quality of the documentation are truly appreciated and inspiring. I am a recent user of JAX, and this type of work makes you appreciate even more how incredible it is.

Thank you! That's absolutely awesome to hear.

To add on to @johannahaffner's solution, I'll note that by default the autodiff rules assume that you've obtained a solution, so autodiff'ing through a solve that hits the maximum number of steps may return the wrong gradients. If that matters to you then you can use adjoint=optx.RecursiveCheckpointAdjoint() to autodiff directly through the solve.
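A hedged sketch of what that could look like, reusing the toy objective from above (here the solve is differentiated with respect to `args`):

```python
import jax
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    # Toy objective parameterised by `args`, purely for illustration.
    return jnp.sum((y - args) ** 2)


def solve(args):
    solution = optx.minimise(
        fn,
        optx.BFGS(rtol=1e-8, atol=1e-8),
        jnp.zeros(3),
        args=args,
        max_steps=100,
        throw=False,
        # Differentiate through the solver's iterations directly, rather than
        # via the implicit function theorem (which assumes an exact solution).
        adjoint=optx.RecursiveCheckpointAdjoint(),
    )
    return jnp.sum(solution.value)


grads = jax.grad(solve)(jnp.ones(3))
```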

(@johannahaffner - I'm realising this is a subtle footgun. I think we could probably catch this by having a custom_jvp that examines if throw=False and adjoint=ImplicitAdjoint() and then raise a warning / compile time error / run time error? Need to think about the details but I think we can offer a better UX here.)

patrick-kidger avatar Jul 15 '25 10:07 patrick-kidger

> (@johannahaffner - I'm realising this is a subtle footgun. I think we could probably catch this by having a custom_jvp that examines if throw=False and adjoint=ImplicitAdjoint() and then raise a warning / compile time error / run time error? Need to think about the details but I think we can offer a better UX here.)

I think I like the idea of a warning here! We have root-finds with throw=False internally in a few places, like the trust region descents. Do we do something related to this for derivatives of ODE solves with events in diffrax? IIRC they use a root-find too.

johannahaffner avatar Jul 16 '25 19:07 johannahaffner

Thanks @johannahaffner and @patrick-kidger! And regarding the BestSoFar wrapper, in the example I showed, wrapping it would return x_2, whereas without the wrap it would return x_3?

Thanks again!

marcelosena avatar Jul 17 '25 01:07 marcelosena

You're welcome!

> And regarding the BestSoFar wrapper, in the example I showed, wrapping it would return x_2, whereas without the wrap it would return x_3?

In your example above, x_1 has the lowest objective value, so x_1 would be returned.

johannahaffner avatar Jul 17 '25 18:07 johannahaffner

Sorry, I edited the original post (the ordering I put assumed maximization instead of minimization...). In the case f(x_2) < f(x_3) < f(x_1), wrapping with BestSoFar returns x_2, while without it, it returns x_3?

marcelosena avatar Jul 17 '25 18:07 marcelosena

Alright! So yes, then it would return x_2.

The reason the last accepted iterate may have a higher objective value (for minimisation) is that loss surfaces might not be perfectly smooth (e.g. due to noisy data), and termination depends on convergence, as measured by the difference to the previous accepted iterate. So we might hit a sweet spot before we get two iterates that are close enough to fulfil the convergence criteria, and depending on the local properties of the loss surface, the last accepted iterate might then not be the one with the best overall loss. The BestSoFar wrappers guard against this by keeping around a copy of the best loss seen so far and its associated value of y.

If your objective function is well behaved, then the difference should not be dramatic, however.
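To tie this back to the original question, a minimal sketch of using the wrapper (again with a placeholder objective and solver):

```python
import jax.numpy as jnp
import optimistix as optx


def fn(y, args):
    # Toy objective, purely for illustration.
    return jnp.sum((y - 1.0) ** 2)


# The wrapper tracks the lowest objective value seen during the solve and
# returns its associated y, rather than the last accepted iterate.
solver = optx.BestSoFarMinimiser(optx.BFGS(rtol=1e-8, atol=1e-8))
solution = optx.minimise(fn, solver, jnp.zeros(3), max_steps=100, throw=False)
```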

johannahaffner avatar Jul 18 '25 06:07 johannahaffner

Right! That's what I suspected, but just wanted to confirm. Thanks again!

marcelosena avatar Jul 18 '25 18:07 marcelosena

You're welcome!

johannahaffner avatar Jul 18 '25 19:07 johannahaffner