Question about `BestSoFar` wrapper and handling of `max_steps`
Hi,
First, I'd like to say thank you for this fantastic package! The active development and quality of the documentation are truly appreciated and inspiring. I am a recent user of JAX, and this kind of work makes you appreciate even more how incredible it is.
I have two questions:
- Understanding the `BestSoFar` wrapper: As I understand it, this wrapper returns the parameters that achieved the best objective value during the entire search, rather than the parameters from the final step. For example, if a search evaluated three points, `x_1`, `x_2`, and `x_3`, with objective values `f(x_2) < f(x_3) < f(x_1)`, the `BestSoFar` wrapper would return `x_2`. Is this understanding correct?
- Behavior upon reaching `max_steps`: I've noticed that when an optimization reaches `max_steps`, an error is raised. This seems to differ from a library like SciPy, which would simply stop and return the last state. Is this correct? For my use case, it would be very useful to get the best result found, even if the optimization terminated by reaching `max_steps`. Is there a recommended way to achieve this and avoid the error?
Thank you again for your time and effort on this project.
Hi,
thanks for the appreciation! It means a lot.
And yes, you can do that - you can disable this error by passing `throw=False` to your favourite top-level API (`optx.{minimise, least_squares, ...}`). In this case it is recommended to examine the solution: maybe the solver reached the maximum number of steps, but a solve may also fail for other reasons, such as failed linear solves.
You'll want to check whether `solution.result == optx.RESULTS.nonlinear_max_steps_reached`, and perhaps also check some other metrics to assess the quality of the returned solution.
For example, you could check how different the values for `y` and `y_eval`, or the function values `f` and `f_eval`, are; these are accessible through the solver state. If they are close, then you may be converging, even if the numerical tolerances are not yet met. (Comparing these is how Optimistix assesses convergence, too.)
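To make this concrete, here is a minimal sketch of that pattern, assuming a toy Rosenbrock objective and a deliberately small `max_steps`, both purely for illustration:

```python
import jax.numpy as jnp
import optimistix as optx


def rosenbrock(y, args):
    # Illustrative scalar objective; substitute your own function.
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


solver = optx.BFGS(rtol=1e-8, atol=1e-8)
y0 = jnp.array([-1.2, 1.0])

# With throw=False, hitting max_steps no longer raises; the outcome is
# recorded in solution.result instead.
solution = optx.minimise(rosenbrock, solver, y0, max_steps=20, throw=False)

hit_max_steps = solution.result == optx.RESULTS.nonlinear_max_steps_reached
succeeded = solution.result == optx.RESULTS.successful
print(solution.value, succeeded, hit_max_steps)
```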
> First, I'd like to say thank you for this fantastic package! The active development and quality of the documentation are truly appreciated and inspiring. I am a recent user of JAX, and this kind of work makes you appreciate even more how incredible it is.
Thank you! That's absolutely awesome to hear.
To add on to @johannahaffner's solution, I'll note that by default the autodiff rules assume that you've obtained a solution, so autodiff'ing through a solve that hits the maximum number of steps may return the wrong gradients. If that matters to you then you can use `adjoint=optx.RecursiveCheckpointAdjoint()` to autodiff directly through the solve.
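As a rough sketch of what that might look like (again with an illustrative objective; the only change relative to the earlier snippet is the `adjoint` argument):

```python
import jax
import jax.numpy as jnp
import optimistix as optx


def rosenbrock(y, args):
    # Same illustrative objective as before; `args` is unused.
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


def truncated_solve(y0):
    # Differentiate through the solver iterations themselves, which stays
    # meaningful even if the solve stops at max_steps before converging.
    solution = optx.minimise(
        rosenbrock,
        optx.BFGS(rtol=1e-8, atol=1e-8),
        y0,
        max_steps=20,
        throw=False,
        adjoint=optx.RecursiveCheckpointAdjoint(),
    )
    return jnp.sum(solution.value**2)


grads = jax.grad(truncated_solve)(jnp.array([-1.2, 1.0]))
```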
(@johannahaffner - I'm realising this is a subtle footgun. I think we could probably catch this by having a `custom_jvp` that examines if `throw=False` and `adjoint=ImplicitAdjoint()` and then raise a warning / compile time error / run time error? Need to think about the details but I think we can offer a better UX here.)
> (@johannahaffner - I'm realising this is a subtle footgun. I think we could probably catch this by having a `custom_jvp` that examines if `throw=False` and `adjoint=ImplicitAdjoint()` and then raise a warning / compile time error / run time error? Need to think about the details but I think we can offer a better UX here.)
I think I like the idea of a warning here! We have root-finds with `throw=False` internally in a few places, like the trust-region descents. Do we do something related to this for derivatives of ODE solves with events in diffrax? IIRC they use a root-find too.
Thanks @johannahaffner and @patrick-kidger! And regarding the `BestSoFar` wrapper, in the example I showed, wrapping it would return `x_2`, whereas without the wrapper it would return `x_3`?
Thanks again!
You're welcome!
> And regarding the `BestSoFar` wrapper, in the example I showed, wrapping it would return `x_2`, whereas without the wrapper it would return `x_3`?
In your example above, `x_1` has the lowest objective value, so `x_1` would be returned.
Sorry, I edited the original post (the ordering I put assumed maximization instead of minimization...). In the case `f(x_2) < f(x_3) < f(x_1)`, wrapping with `BestSoFar` returns `x_2`, while without it, it returns `x_3`?
Alright! So yes, then it would return `x_2`.
The reason the last accepted iterate may have a higher objective value (for minimisation) is that loss surfaces might not be perfectly smooth (e.g. due to noisy data), and termination depends on convergence, as measured by the difference to the previous accepted iterate. So we might hit a sweet spot before we get two iterates that are close enough to fulfil the convergence criteria, and depending on the local properties of the loss surface, the last accepted iterate might then not be the one with the best overall loss. The `BestSoFar` wrappers guard against this by keeping around a copy of the best loss seen so far and its associated value of `y`.
If your objective function is well behaved, then the difference should not be dramatic, however.
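For completeness, here is a minimal sketch of using the minimisation variant of the wrapper, assuming it is `optx.BestSoFarMinimiser`, with a made-up bumpy objective purely for illustration:

```python
import jax.numpy as jnp
import optimistix as optx


def bumpy_objective(y, args):
    # Made-up objective: a quadratic plus small oscillations, standing in
    # for a loss surface that is not perfectly smooth.
    return jnp.sum(y**2) + 0.01 * jnp.sum(jnp.sin(50.0 * y))


# The wrapper keeps track of the lowest objective value seen during the
# solve and returns its associated y, rather than the final iterate.
solver = optx.BestSoFarMinimiser(optx.BFGS(rtol=1e-6, atol=1e-6))

solution = optx.minimise(
    bumpy_objective, solver, jnp.ones(3), max_steps=200, throw=False
)
print(solution.value)
```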
Right! That's what I suspected, but just wanted to confirm. Thanks again!
You're welcome!