Updating initial guess when using nonlinear solver inside ODE term
Hi,
I think this is partly related to https://github.com/patrick-kidger/diffrax/issues/60, as it involves storing some information after an accepted step. The difference is that I also need to access the most recently stored information from inside the ODE function.
I have an ODE function that requires the use of a nonlinear solver to compute the derivatives. At the moment, I'm using a fixed initial guess for NewtonNonlinearSolver, but this is inefficient. What I'd like to do is, after an accepted step, store the found root and use it as the initial guess during the next integration step. I was doing this in torchdiffeq successfully, but I can't see an equivalent way in Diffrax.
As a (contrived) example: the code below performs some sort of nonlinear solve, but each time with a poor initial guess (meaning it takes 10 iterations to converge at each call to the ODE function). If I set init_x = 0.9, which is a much better guess in this case, it takes two or three iterations, so the potential benefit is clear (especially for more expensive nonlinear functions). In this case, I wouldn't expect to run into weird issues with gradients, because backpropagating through NewtonNonlinearSolver shouldn't depend on the initial guess.
Thanks!
from diffrax import diffeqsolve, ODETerm, Dopri5, NewtonNonlinearSolver
import jax.debug
import jax.numpy as jnp

init_x = 0.1
nl_solver = NewtonNonlinearSolver(rtol=1e-3, atol=1e-6)

def f_nonlinear(x, y):
    return jnp.cos(y * x) - x**3

def f(t, y, args):
    # Find x with f_nonlinear(x, y) = 0, always starting from the fixed init_x.
    sol = nl_solver(f_nonlinear, init_x, y)
    jax.debug.print(
        "t={t}, {n} iterations, x={x}",
        t=t,
        n=sol.num_steps,
        x=sol.root,
    )
    return -sol.root

term = ODETerm(f)
solver = Dopri5()
y0 = 1.0
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
Aside: this made-up example is actually quite severe: printing sol.result shows that the solve usually fails to converge with init_x = 0.1 but always converges with init_x = 0.9, so it would particularly benefit from being able to update the initial guess.
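To illustrate the warm start outside Diffrax, here is a minimal pure-JAX sketch. It assumes a hand-written fixed-step RK4 (standing in for diffeqsolve with a constant step of 0.1) and a hand-written scalar Newton iteration instead of NewtonNonlinearSolver; the helpers newton, vector_field and rk4 are hypothetical. The guess is threaded through the lax.scan carry, which is the functional way to pass state from one completed step to the next. The cold-start run restarts every solve from 0.1, as in the code above.

```python
import jax
import jax.numpy as jnp


def newton(fn, x0, args, rtol=1e-3, atol=1e-6, max_steps=50):
    """Hypothetical scalar Newton iteration; returns (root, iterations used)."""
    dfn = jax.grad(fn)  # derivative of fn with respect to x

    def cond(state):
        x, dx, n = state
        return (n < max_steps) & (jnp.abs(dx) > atol + rtol * jnp.abs(x))

    def body(state):
        x, _, n = state
        dx = -fn(x, args) / dfn(x, args)
        return x + dx, dx, n + 1

    init = (
        jnp.asarray(x0, dtype=jnp.float32),
        jnp.asarray(jnp.inf, dtype=jnp.float32),
        jnp.asarray(0, dtype=jnp.int32),
    )
    x, _, n = jax.lax.while_loop(cond, body, init)
    return x, n


def f_nonlinear(x, y):
    return jnp.cos(y * x) - x**3


def vector_field(y, x_guess):
    # dy/dt = -x(y), where x(y) solves f_nonlinear(x, y) = 0 (autonomous, so no t).
    x, n = newton(f_nonlinear, x_guess, y)
    return -x, x, n


def rk4(y0, x_init, dt, num_steps, warm_start):
    """Fixed-step RK4; optionally carries the last root forward as the next guess."""

    def step(carry, _):
        y, x_guess = carry
        # All four stages of this step start Newton from the same guess.
        k1, _, n1 = vector_field(y, x_guess)
        k2, _, n2 = vector_field(y + dt / 2 * k1, x_guess)
        k3, _, n3 = vector_field(y + dt / 2 * k2, x_guess)
        k4, x4, n4 = vector_field(y + dt * k3, x_guess)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        # The step is complete, so the last stage's root (computed near y_next)
        # becomes the guess for the next step; a cold start keeps the old guess.
        next_guess = x4 if warm_start else x_guess
        return (y_next, next_guess), n1 + n2 + n3 + n4

    init = (jnp.asarray(y0, dtype=jnp.float32), jnp.asarray(x_init, dtype=jnp.float32))
    (y_final, _), iters = jax.lax.scan(step, init, None, length=num_steps)
    return y_final, iters


y_cold, iters_cold = rk4(1.0, 0.1, dt=0.1, num_steps=10, warm_start=False)
y_warm, iters_warm = rk4(1.0, 0.1, dt=0.1, num_steps=10, warm_start=True)
print("cold:", y_cold, int(iters_cold.sum()), "Newton iterations")
print("warm:", y_warm, int(iters_warm.sum()), "Newton iterations")
```

Both runs should agree on y(1) to within the Newton tolerance, while the warm-started run needs far fewer Newton iterations after its first step.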
Second aside: I misunderstood the documentation at https://docs.kidger.site/diffrax/api/nonlinear_solver/. I was expecting the default option tolerate_nonconvergence=False to set root to something like jnp.nan in the case of nonconvergence, so I was lazily not checking the return code. The documentation could also be clearer about what the return codes actually mean (e.g. by referring to diffrax.RESULTS). Also, the descriptions in RESULTS specifically refer to "implicit methods", which may be confusing when the codes come from the Newton solver used outside the context of implicit solvers. I can contribute a PR to try to clear up some of this if you like.
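On the nonconvergence point, here is a minimal sketch, in plain JAX, of a Newton solve that makes failure hard to miss by returning NaN instead of a root. newton_or_nan is a hypothetical helper, not Diffrax's API, and its stopping test (size of the last Newton update) is a simple assumption that may not match Diffrax's own criterion.

```python
import jax
import jax.numpy as jnp


def newton_or_nan(fn, x0, args, rtol=1e-3, atol=1e-6, max_steps=10):
    """Hypothetical scalar Newton solve: returns (root or NaN, iterations, converged)."""
    dfn = jax.grad(fn)

    def small_enough(x, dx):
        # NaN updates compare False here, so they also count as "not converged".
        return jnp.abs(dx) <= atol + rtol * jnp.abs(x)

    def cond(state):
        x, dx, n = state
        return (n < max_steps) & ~small_enough(x, dx) & ~jnp.isnan(dx)

    def body(state):
        x, _, n = state
        dx = -fn(x, args) / dfn(x, args)
        return x + dx, dx, n + 1

    init = (
        jnp.asarray(x0, dtype=jnp.float32),
        jnp.asarray(jnp.inf, dtype=jnp.float32),
        jnp.asarray(0, dtype=jnp.int32),
    )
    x, dx, n = jax.lax.while_loop(cond, body, init)
    converged = small_enough(x, dx)
    # Replace the root with NaN on failure instead of silently returning the last iterate.
    return jnp.where(converged, x, jnp.nan), n, converged


def f_nonlinear(x, y):
    return jnp.cos(y * x) - x**3


root_good, n_good, ok_good = newton_or_nan(f_nonlinear, 0.9, 1.0)
# An artificially small iteration budget from the poor guess, to force a failure.
root_bad, n_bad, ok_bad = newton_or_nan(f_nonlinear, 0.1, 1.0, max_steps=2)
print(root_good, n_good, ok_good)
print(root_bad, n_bad, ok_bad)
```

Because the NaN propagates into the vector field's output, a failed solve can't silently pass for a valid derivative.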
So you'd like to pass data between vector field evaluations. It's worth noting that this isn't really a clearly-defined notion, mathematically speaking: a diffeq solver may evaluate the vector field nonmonotonically (not just forward in time), in particular when using an adaptive solver that may reject steps.
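To make the point about nonmonotonic evaluation concrete, here is a toy adaptive Euler/Heun integrator in plain Python that logs every time at which the vector field is called. It is not Diffrax's implementation; make_recording_field and heun_euler are hypothetical, and the step-size rule is a common textbook choice. With a deliberately large first step on the fast-decaying problem dy/dt = -50y, the first attempt is rejected and the retry evaluates the vector field at an earlier time than the rejected attempt did.

```python
def make_recording_field(rate):
    """Return a vector field dy/dt = -rate * y that logs every evaluation time."""
    times = []

    def field(t, y):
        times.append(t)
        return -rate * y

    return field, times


def heun_euler(field, t0, t1, y0, dt, tol, dt_min=1e-10):
    """Toy adaptive integrator: Euler (order 1) embedded in Heun (order 2)."""
    t, y = t0, y0
    k1 = field(t, y)  # reused across rejected attempts of the same step
    while t1 - t > 1e-12:
        dt = min(dt, t1 - t)
        y_euler = y + dt * k1
        k2 = field(t + dt, y_euler)
        y_heun = y + dt / 2 * (k1 + k2)
        err = abs(y_heun - y_euler)
        if err <= tol:
            t, y = t + dt, y_heun  # accept the step
            k1 = field(t, y)
        # Step-size update, clipped to between 0.2x and 2x the previous step.
        factor = 0.9 * (tol / err) ** 0.5 if err > 0 else 2.0
        dt *= min(2.0, max(0.2, factor))
        if dt < dt_min:
            raise RuntimeError("step size underflow")
    return y


field, eval_times = make_recording_field(rate=50.0)
heun_euler(field, t0=0.0, t1=1.0, y0=1.0, dt=0.5, tol=1e-3)
# Pairs of consecutive evaluations where time went backwards.
backwards = [(a, b) for a, b in zip(eval_times, eval_times[1:]) if b < a]
print(eval_times[:4], len(backwards))
```

Any state updated inside `field` would therefore see the rejected attempt at t = 0.5 before the accepted one at an earlier time.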
That said, I agree that it can be very useful to be able to do this!
Supporting this kind of side effect hasn't been a priority for JAX so far, as side effects are quite complicated to make happen in a functional framework. Nonetheless, I think this should be possible using an upcoming JAX API that provides stateful operations.
I haven't tried it myself, and it isn't documented yet, but it might suffice for this task. The operations are available here, and you can see an example of them being used in for_loop.
As for the nonlinear solvers, I actually have an overhaul of these planned myself. These are going to be dramatically improved soon.
Thank you for the response - I'll keep an eye out for when that API is documented as it should be useful, but I don't currently have the time to try and figure it out for myself!
Regarding variable-step solvers, just to clarify: I was envisaging updating the state only after an accepted step, rather than on every call to the ODE function. Along similar lines, if a nonlinear solve inside the ODE function fails, it would be nice to have a mechanism to force an adaptive solver to reject that step (which is what I was doing in https://github.com/rtqichen/torchdiffeq/pull/210). (Edit: I've since seen that there are already ways to do this in Diffrax: https://github.com/patrick-kidger/diffrax/issues/200#issuecomment-1341525056, https://github.com/patrick-kidger/diffrax/issues/194#issuecomment-1328111702)
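Both ideas together can be sketched in a standalone toy, in plain Python. It updates the Newton guess only on accepted steps and treats a failed inner solve as a rejected step that shrinks dt. It is not how Diffrax implements either feature (its own mechanisms are the ones in the linked comments). NewtonFailed, newton and solve are hypothetical, and the adaptive scheme is an assumed Euler/Heun embedded pair with a simple step-size rule.

```python
import math


class NewtonFailed(Exception):
    """Raised when the inner Newton solve does not converge."""


def newton(x0, y, rtol=1e-3, atol=1e-6, max_steps=20):
    """Solve cos(y * x) - x**3 = 0 for x; return (root, iterations used)."""
    x = x0
    for n in range(1, max_steps + 1):
        f = math.cos(y * x) - x**3
        df = -y * math.sin(y * x) - 3 * x**2
        dx = -f / df
        x += dx
        if abs(dx) <= atol + rtol * abs(x):
            return x, n
    raise NewtonFailed


def solve(y0, t0, t1, dt, x_init, tol=1e-6, dt_min=1e-10):
    """Adaptive Euler/Heun for dy/dt = -x(y), warm-starting Newton between steps."""
    t, y, guess = t0, y0, x_init
    accepted = rejected = newton_iters = 0
    while t1 - t > 1e-12:
        dt = min(dt, t1 - t)
        try:
            x1, n1 = newton(guess, y)
            y_euler = y - dt * x1
            x2, n2 = newton(guess, y_euler)
        except NewtonFailed:
            # A failed inner solve rejects the step, just like a too-large error.
            rejected += 1
            dt *= 0.5
            if dt < dt_min:
                raise RuntimeError("step size underflow")
            continue
        newton_iters += n1 + n2  # counts iterations of successful solves only
        y_heun = y - dt / 2 * (x1 + x2)
        err = abs(y_heun - y_euler)
        if err <= tol:
            t, y = t + dt, y_heun
            guess = x2  # only an accepted step updates the warm start
            accepted += 1
        else:
            rejected += 1
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        dt *= min(2.0, max(0.2, factor))
        if dt < dt_min:
            raise RuntimeError("step size underflow")
    return y, accepted, rejected, newton_iters


y1, accepted, rejected, iters = solve(y0=1.0, t0=0.0, t1=1.0, dt=0.1, x_init=0.1)
print(y1, accepted, rejected, iters)
```

Halving dt on failure only helps when the failure comes from a large jump in y within the step; if the stored guess itself is poor, the guard on dt_min stops the loop instead of shrinking forever.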
Good to hear about the upcoming nonlinear solver changes too!