
`NewtonChord` w/ `cauchy_termination=False` divergence test unhelpful when `atol` is low

Open jpbrodrick89 opened this issue 4 months ago • 6 comments

When gradients steepen as the root is approached, it is expected that differences in y will increase from step to step; this can cause issues when relying largely or entirely on rtol.

MWE

import lineax as lx
import optimistix as optx
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def tanh_with_args(y, args):
    return jnp.tanh(y)

solver = optx.Newton(atol=1e-8, rtol=1e-5, linear_solver=lx.Diagonal(), cauchy_termination=False)
optx.root_find(tanh_with_args, solver, 0.5)
Traceback:

EquinoxRuntimeError                       Traceback (most recent call last)
Cell In[4], line 13
     11 x = jnp.arange(-3,3)
     12 solver = optx.Newton(atol=1e-8, rtol=1e-5, linear_solver=lx.Diagonal(), cauchy_termination=False)
---> 13 optx.root_find(tanh_with_args, solver, 0.5)

EquinoxRuntimeError: Above is the stack outside of JIT. Below is the stack inside of JIT:

File "/Users/jonathanbrodrick/pasteurcodes/lagradept/.venv/lib/python3.13/site-packages/optimistix/_root_find.py", line 220, in root_find
    return iterative_solve(
        fn,
        ...<10 lines>...
        rewrite_fn=_rewrite_fn,
    )
File "/Users/jonathanbrodrick/pasteurcodes/lagradept/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 351, in iterative_solve
    sol = result.error_if(sol, result != RESULTS.successful)
File "/Users/jonathanbrodrick/pasteurcodes/lagradept/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.func(self.__self__, *args, **kwargs)

equinox.EquinoxRuntimeError: Nonlinear solve diverged.

For now I will just use cauchy_termination=True
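
For reference, a sketch of that workaround on the same MWE (I'd expect this to converge to the root at 0):

# Same problem as above, but with Cauchy termination, which terminates based on
# the size of successive differences rather than the rate-based divergence test.
solver = optx.Newton(atol=1e-8, rtol=1e-5, linear_solver=lx.Diagonal(), cauchy_termination=True)
sol = optx.root_find(tanh_with_args, solver, 0.5)
print(sol.value)  # expected to be very close to 0.0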

jpbrodrick89 commented Sep 10 '25 14:09

Hi @jpbrodrick89,

The relative and absolute tolerances are only used within cauchy_termination, I think. Perhaps you could clarify what you mean by relying on rtol here?

johannahaffner commented Sep 10 '25 14:09

Ah, thanks for correcting this oversight on my part!

If I understand correctly, then you have a use case in which the gradients are steep near/at the root, and this can result in a spuriously flagged divergence.

That is a bit surprising to me; I think a steep gradient (and small function value) should produce smaller differences in y, at least in general. Do you perhaps have very ill-conditioned Jacobians?

johannahaffner commented Sep 10 '25 15:09

rtol and atol affect the scale applied to diffsize (lines 137–139):

        scale = (self.atol + self.rtol * ω(new_y).call(jnp.abs)).ω
        with jax.numpy_dtype_promotion("standard"):
            diffsize = self.norm((diff**ω / scale**ω).ω)

These diffsizes are then compared in lines 183–186:

            rate = state.diffsize / state.diffsize_prev
            factor = state.diffsize * rate / (1 - rate)
            small = _small(state.diffsize)
            diverged = _diverged(rate)

If atol=0 then rate is just the ratio between the new and previous relative changes (i.e. (new_y - y) / new_y and (y - ym1) / y); however, as atol is increased this ratio gets damped down a bit.

_diverged flags True when rate > 2. Therefore, atol can damp a ratio greater than 2 to a ratio between 1 and 2.
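
To make this concrete with the MWE above, here is a rough by-hand reproduction of that rate computation for the first two Newton steps on tanh (just the formulas quoted above applied manually, not the actual optimistix code path):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

atol, rtol = 1e-8, 1e-5

def newton_step(y):
    # One Newton step for f(y) = tanh(y), using f'(y) = 1 - tanh(y)**2.
    return y - jnp.tanh(y) / (1.0 - jnp.tanh(y) ** 2)

def diffsize(y_old, y_new):
    # scale = atol + rtol * |new_y|; diffsize = |diff| / scale, as in the snippet above.
    scale = atol + rtol * jnp.abs(y_new)
    return jnp.abs(y_new - y_old) / scale

y0 = 0.5
y1 = newton_step(y0)  # roughly -0.0876
y2 = newton_step(y1)  # roughly 0.00045

rate = diffsize(y1, y2) / diffsize(y0, y1)
print(rate)  # roughly 9, well above the threshold of 2, even though |y2 - y1| < |y1 - y0|

So in this example the scaled diffsize grows even though the absolute steps shrink, because new_y heads to zero faster than the step size and the scale collapses with it; a larger atol would put a floor under the scale and (here) keep the rate below 2.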

jpbrodrick89 commented Sep 10 '25 15:09

By relying on rtol, I mean the case where I set atol very low or to zero, because f and y have different scales and rtol is more generalisable.

jpbrodrick89 commented Sep 10 '25 15:09

So you mean that setting a high value for the absolute tolerance makes the solve a little more permissive and allows larger steps. That much makes sense to me, although in practice one almost always chooses the absolute tolerance to be lower than the relative tolerance, so the effect should be small.

In general, I would assume that the difference between successive values decreases throughout the solve. Do you think that without cauchy_termination, we diagnose divergence too early, either in general or when the value chosen for the absolute tolerance is low? Or do you think that the threshold of two is too strict? Doubling step sizes in a root find does seem like undesirable behaviour to me.

johannahaffner commented Sep 10 '25 16:09

So I think this is expected. What is going on here is that cauchy_termination=False corresponds to a preference for failing a solve, rather than performing an expensive solve.

In particular this is useful for the root finds that occur in implicit ODE solvers, for which we're happy to try again with a smaller step size.

(FWIW we could probably adjust the divergence heuristic -- e.g. to only trigger if the residual is not also decreasing in some satisfactory manner -- if it would allow the use of this flag in a broader range of problems.)
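
Purely as an illustration of that idea (the names, the 0.9 factor, and the overall shape are made up here, not anything in optimistix), such a check could look roughly like:

import jax.numpy as jnp

def softened_diverged(rate, f_new, f_prev, threshold=2.0, shrink=0.9):
    # Hypothetical: only flag divergence if the scaled step ratio blows up
    # AND the residual is not also shrinking by some satisfactory factor.
    # `threshold` and `shrink` are illustrative values, not optimistix defaults.
    residual_improving = jnp.linalg.norm(f_new) < shrink * jnp.linalg.norm(f_prev)
    return (rate > threshold) & jnp.logical_not(residual_improving)

# Roughly the numbers from the tanh example above: the step ratio is ~9 but the
# residual keeps shrinking, so this would not flag divergence.
print(softened_diverged(9.2, jnp.array([4.5e-4]), jnp.array([8.8e-2])))  # False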

patrick-kidger commented Sep 10 '25 20:09