Allow for dynamic `atol`
I am trying to use a Newton solver inside a diffrax simulation to solve a problem that is nearly linear (and is sometimes completely linear, but I can't guarantee this in advance; think of a linearly interpolated tabular function), and I would like to use `cauchy_termination`. However, I don't know the scale of my values in advance, and it may change from cell to cell or over the course of the simulation. I have tried to address this by modifying the objective function, but the increased nonlinearity causes issues.
As an example (not really an MWE, since it doesn't demonstrate the problem itself), let's use an actually linear objective function:
```python
import jax.numpy as jnp
import lineax as lx
import optimistix as optx

solver = optx.Newton(atol=1e-14, rtol=0.0, linear_solver=lx.Diagonal())


def linear_objective_function(y, args):
    return y - args["y_target"]


def relative_objective_function(y, args):
    return y / args["y_target"] - 1.0


sol = optx.root_find(
    linear_objective_function,  # or relative_objective_function
    solver,
    jnp.logspace(0.0, 12.0, num=100),
    args={"y_target": 0.9 * jnp.logspace(0.0, 12.0, num=100)},
    tags=frozenset({lx.diagonal_tag}),
)
```
The linear objective function theoretically needs only one step to reach the optimal solution (although `sol.state.step = 2`), but with floating point error a single fixed `atol` is fragile in a real-world case: it may fail to converge because the tolerance is set too low, or be flagged as converged too early in a nonlinear setting because the tolerance is set too high.
However, the relative objective function requires 4 steps (or more in a slightly nonlinear case), and in my testing it shows larger discrepancies against finite differences, even with fairly large relative bumps in `y_target`.
`atol` and `rtol` are type-hinted as `float`, not array-like, so my assumption is that they cannot be tracers.
I tried `cauchy_termination=False`, where I can use `rtol`, but this requires at least two iterations (unnecessary when the problem is actually linear) and caused another issue I can't remember off the top of my head (although that could have been coming from elsewhere and may be something I've since fixed).
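For reference, a minimal sketch of that configuration, assuming `cauchy_termination` is accepted as a constructor argument of `optx.Newton` (as used above) and with illustrative tolerance values:

```python
# Relative-tolerance termination without the Cauchy criterion (sketch).
solver = optx.Newton(
    atol=0.0,
    rtol=1e-8,
    linear_solver=lx.Diagonal(),
    cauchy_termination=False,
)
```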
If I understand correctly, you do know how you would like to modify the absolute tolerance, based on problem characteristics diagnosed at runtime within a particular solve.
In this case, I would recommend that you subclass `optx.Newton` and override its `terminate` method to include the tolerance-tweaking logic you want to add. Then you can call `cauchy_termination` with the computed value.
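A rough sketch of what that could look like, assuming the `terminate` signature `(fn, y, args, options, state, tags)` used by Optimistix root finders; the scaling heuristic and the `args["y_target"]` lookup are placeholders for whatever runtime diagnosis you have in mind:

```python
import equinox as eqx
import jax.numpy as jnp
import optimistix as optx


class DynamicAtolNewton(optx.Newton):
    """Newton solver whose absolute tolerance is rescaled at runtime (sketch)."""

    def terminate(self, fn, y, args, options, state, tags):
        # Placeholder heuristic: tie the absolute tolerance to the magnitude
        # of the target values for this particular solve.
        dynamic_atol = 1e-14 * jnp.max(jnp.abs(args["y_target"]))
        # Equinox modules are immutable, so build a copy of the solver carrying
        # the computed tolerance and delegate to the standard Newton
        # termination logic.
        solver = eqx.tree_at(lambda s: s.atol, self, dynamic_atol)
        return optx.Newton.terminate(solver, fn, y, args, options, state, tags)
```

This can then be constructed and passed to `optx.root_find` exactly like `optx.Newton` in the example above.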
And yes, the tolerances are floats: this is to make it obvious that they are assumed to be constant throughout a solve, set by the user before handing an optimisation problem to Optimistix. You could probably change that if you wanted; I expect it would require at most minor fixes. But it is probably also not required for what you want to do.
Actually, I think we might already support traced tolerances? At least, I think that was what I was aiming for when originally writing Optimistix.
I think this might just be a type annotation that is wrong.
This is defined here https://github.com/patrick-kidger/optimistix/blob/9927984fb8cbec77f9514fad7af076dce64e3993/optimistix/_iterate.py#L31
and since we used to require strict inheritance, it is done that way everywhere. Changing a solver attribute inside the `while_loop` would require a `tree_at`, and putting the solver into the carried state, I think?
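For completeness, the caller-facing pattern being discussed (a traced `atol` passed straight through, despite the `float` annotation) would look roughly like the following, reusing `linear_objective_function` from the example above; whether this currently works end-to-end is exactly the open question:

```python
import jax

@jax.jit
def solve_with_traced_atol(y0, y_target):
    # Hypothetical: the tolerance is a traced scalar derived from the data.
    atol = 1e-14 * jnp.max(jnp.abs(y_target))
    solver = optx.Newton(atol=atol, rtol=0.0, linear_solver=lx.Diagonal())
    return optx.root_find(
        linear_objective_function,
        solver,
        y0,
        args={"y_target": y_target},
        tags=frozenset({lx.diagonal_tag}),
    )
```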