`NewtonChord` with `cauchy_termination=False`: divergence test unhelpful when `atol` is low
When gradients steepen as the root is approached, it is expected that the differences in y will increase from step to step, which can cause issues when relying largely or entirely on `rtol`.
MWE
```python
import jax
import jax.numpy as jnp
import lineax as lx
import optimistix as optx

jax.config.update("jax_enable_x64", True)


def tanh_with_args(y, args):
    return jnp.tanh(y)


solver = optx.Newton(
    atol=1e-8, rtol=1e-5, linear_solver=lx.Diagonal(), cauchy_termination=False
)
optx.root_find(tanh_with_args, solver, 0.5)
```
Traceback
```
EquinoxRuntimeError                       Traceback (most recent call last)
Cell In[4], line 13
     11 x = jnp.arange(-3,3)
     12 solver = optx.Newton(atol=1e-8, rtol=1e-5, linear_solver=lx.Diagonal(), cauchy_termination=False)
---> 13 optx.root_find(tanh_with_args, solver, 0.5)

EquinoxRuntimeError: Above is the stack outside of JIT. Below is the stack inside of JIT:

  File "/Users/jonathanbrodrick/pasteurcodes/lagradept/.venv/lib/python3.13/site-packages/optimistix/_root_find.py", line 220, in root_find
    return iterative_solve(
        fn,
        ...<10 lines>...
        rewrite_fn=_rewrite_fn,
    )
  File "/Users/jonathanbrodrick/pasteurcodes/lagradept/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 351, in iterative_solve
    sol = result.error_if(sol, result != RESULTS.successful)
  File "/Users/jonathanbrodrick/pasteurcodes/lagradept/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.func(self.self, *args, **kwargs)

equinox.EquinoxRuntimeError: Nonlinear solve diverged.
```
For now I will just use `cauchy_termination=True`.
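A minimal sketch of that workaround, reusing the MWE above (with the Cauchy criterion the solve terminates on the changes in y and f against the tolerances rather than on the rate test, as far as I understand it):

```python
# Same MWE as above, but with Cauchy termination enabled.
solver = optx.Newton(
    atol=1e-8, rtol=1e-5, linear_solver=lx.Diagonal(), cauchy_termination=True
)
sol = optx.root_find(tanh_with_args, solver, 0.5)
print(sol.value)  # ~0.0, the root of tanh
```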
Hi @jpbrodrick89,
The relative and absolute tolerances are only used within `cauchy_termination`, I think. Perhaps you could clarify what you mean by relying on `rtol` here?
Ah, thanks for correcting this oversight on my part!
If I understand correctly, then you have a use case in which the gradients are steep near/at the root, and this can result in a spuriously flagged divergence.
That is a bit surprising to me; I would think a steep gradient (and small function value) should produce smaller differences in y, at least in general. Do you perhaps have very ill-conditioned Jacobians?
`rtol` and `atol` affect the scale applied to `diffsize` (lines 137–139):
```python
scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
with jax.numpy_dtype_promotion("standard"):
    diffsize = self.norm((diff**ω / scale**ω).ω)
```
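(For anyone following along: `ω` is the pytree-arithmetic helper from `equinox.internal`, so these two lines compute an elementwise `diff / (atol + rtol * |new_y|)` before taking a norm. A hand-rolled sketch of the same computation on a flat array, with `jnp.max` standing in for whatever `self.norm` is configured to be:)

```python
import jax.numpy as jnp

# Hand-rolled version of the two lines above for a single array; ω just
# lifts the same arithmetic leafwise over arbitrary pytrees.
atol, rtol = 1e-8, 1e-5
y = jnp.array([0.5])
new_y = jnp.array([-0.0876])  # first Newton iterate from the MWE
diff = new_y - y
scale = atol + rtol * jnp.abs(new_y)
diffsize = jnp.max(jnp.abs(diff / scale))  # stand-in for self.norm
```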
These quantities are then used in the divergence check in lines 183–186:
```python
rate = state.diffsize / state.diffsize_prev
factor = state.diffsize * rate / (1 - rate)
small = _small(state.diffsize)
diverged = _diverged(rate)
```
If `atol=0`, then `rate` is just the ratio between the new and previous relative changes, i.e. `(new_y - y) / new_y` and `(y - ym1) / y`, where `ym1` is the iterate before `y` (the `rtol` factor cancels in the ratio). However, as `atol` is increased, this ratio gets damped down a bit.
`_diverged` flags `True` when `rate > 2`. Therefore, `atol` can damp a ratio greater than 2 to a ratio between 1 and 2.
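To make this concrete for the MWE: below is a hand-rolled trace of the Newton iterates, reimplementing the `scale`/`rate` logic outside the solver (the names mirror the snippets above, not the actual internals). The root of `tanh` is at zero, so `|new_y|` shrinks much faster than the step itself, `diffsize` grows, and `rate` exceeds 2 even though the iterates are converging:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

atol, rtol = 1e-8, 1e-5
y = jnp.array(0.5)
diffsize_prev = None
for _ in range(3):
    new_y = y - jnp.tanh(y) / (1 - jnp.tanh(y) ** 2)  # Newton step for tanh
    diff = new_y - y
    scale = atol + rtol * jnp.abs(new_y)
    diffsize = jnp.abs(diff) / scale
    if diffsize_prev is not None:
        # First printed rate is ~9 > 2, which is what trips _diverged,
        # even though |diff| itself is shrinking at every step.
        print(f"rate = {float(diffsize / diffsize_prev):.2f}")
    diffsize_prev = diffsize
    y = new_y
```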
By relying on `rtol`, I mean setting `atol` very low or to zero, because f and y have different scales and `rtol` is more generalisable.
So you mean that having a high value for the absolute tolerance makes the solve a little more permissive and allows larger steps. That much makes sense to me, although in practice one almost always chooses the absolute tolerance to be lower than the relative tolerance, so the effect should be small.
In general, I would assume that the difference between successive values decreases throughout the solve. Do you think that without `cauchy_termination`, we diagnose divergence too early, either in general or when the value chosen for the absolute tolerance is low? Or do you think that the threshold of two is too strict? Doubling step sizes in a root find does seem like undesirable behaviour to me.
So I think this is expected. What is going on here is that `cauchy_termination=False` corresponds to a preference for failing a solve, rather than performing an expensive solve.
In particular this is useful for the root finds that occur in implicit ODE solvers, for which we're happy to try again with a smaller step size.
(FWIW we could probably adjust the divergence heuristic -- e.g. to only trigger if the residual is not also decreasing in some satisfactory manner -- if it would allow the use of this flag in a broader range of problems.)
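A rough sketch of what that adjusted heuristic might look like (hypothetical, not current optimistix code; `f_norm` and `state.f_norm_prev` are assumed names, since the Newton state would need to start tracking residual norms for this):

```python
# Hypothetical variant of the divergence check: only flag divergence
# when the step sizes are growing AND the residual has stopped
# improving. `f_norm` / `state.f_norm_prev` are assumed names.
rate = state.diffsize / state.diffsize_prev
residual_improving = f_norm < state.f_norm_prev
diverged = _diverged(rate) & jnp.invert(residual_improving)
```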