diffrax
How to use DiscreteTerminatingEvent to terminate integration as soon as any ODETerm value becomes NaN?
I am using Diffrax to integrate magnetic field line trajectories. The magnetic field is defined by a spline over a fixed region of space only, and it returns NaN when evaluated outside that region. I would like to use a DiscreteTerminatingEvent to stop the integration as soon as a NaN is detected in the ODETerm. How can I do this?
Note that I don't want the solver to start adaptively shrinking the timestep once a NaN is detected; I simply want to stop the integration entirely (since a NaN means the trajectory is leaving the region where the magnetic field representation is valid).
I tried doing this with:
```python
def default_terminating_event_fxn(state, **kwargs):
    terms = kwargs.get("terms", lambda a, x, b: x)
    return jnp.any(jnp.isnan(terms.vf(0, state.y, 0)))
```
but the solver still runs into the point where NaNs occur and keeps decreasing the step size until max_steps is reached. Is there a way for the event to access what the next value of the ODETerm will be, and act on that information?
I am unsure exactly why this is happening, because I could not print the values inside the event function I wrote: even with jax.debug.print, it only printed tracer values rather than actual numbers.
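On the printing problem: a plain print inside traced code runs once, at trace time, so it can only show tracers. jax.debug.print shows concrete values at run time, provided the values are passed as format arguments rather than pre-formatted with an f-string (an f-string is evaluated at trace time and bakes the tracer's repr into the message). A minimal sketch, where check_nan is a hypothetical stand-in for the event function:

```python
import jax
import jax.numpy as jnp


@jax.jit
def check_nan(y):
    # Runs once, at trace time: `y` is a tracer here, so this shows a tracer.
    print("trace-time print:", y)
    # Runs every time the compiled function executes and shows real values.
    # Pass values as keyword arguments; do not pre-format with an f-string.
    jax.debug.print("run-time y = {y}", y=y)
    return jnp.any(jnp.isnan(y))


has_nan = check_nan(jnp.array([1.0, jnp.nan]))
no_nan = check_nan(jnp.array([1.0, 2.0]))
```

The second call has the same shape and dtype, so it reuses the compiled function: the trace-time print appears once, while the jax.debug.print line appears on both calls. The same pattern should work inside a DiscreteTerminatingEvent function, since that is also traced by JAX.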
Thanks for the help!
MWE (Not a magnetic field but shows the behavior):
```python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, DiscreteTerminatingEvent
import jax
import jax.numpy as jnp
from jax.lax import cond

def vector_field(t, y, args):
    true_fun = lambda t, x, args: -x
    false_fun = lambda t, x, args: jnp.nan * jnp.ones_like(x)
    return cond(t < 2.5, true_fun, false_fun, t, y, args)

term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)

def default_terminating_event_fxn(state, **kwargs):
    terms = kwargs.get("terms", lambda a, x, b: x)
    return jnp.any(jnp.isnan(terms.vf(0, state.y, 0)))

terminating_event = DiscreteTerminatingEvent(default_terminating_event_fxn)
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller, discrete_terminating_event=terminating_event)
# hits max steps and fails with XLARuntimeError, when instead expect it to complete successfully
# after exiting once NaN is reached at t=2.5
print(sol.ts)  # DeviceArray([0. , 1. , 2. ,inf ])
print(sol.ys)  # DeviceArray([1. , 0.368, 0.135, inf])
```
Right! So the reason you're seeing this, at least in your MWE, is your definition of the termination function: you're passing t=0 to the vector field on the line jnp.any(jnp.isnan(terms.vf(0, state.y, 0))). This means you never actually trigger the NaN-generating branch of the vector field, which only activates for t >= 2.5.
I'd probably suggest an event function that looks something like this instead:
```python
def event_fn(state, *, terms, args, **kwargs):
    small_step = jnp.abs(state.tnext - state.tprev) < 1e-3  # some tolerance
    vf = terms.vf(state.tnext, state.y, args)
    return small_step & jnp.any(jnp.isnan(vf))
```
To explain this: there is both a state.tprev and a state.tnext. These define the interval over which the next step will be attempted. (And then either accepted, in which case tprev gets updated to tnext and tnext gets set to something new, or rejected, in which case tprev stays the same and tnext is reduced.)
So we evaluate using tnext, as we're checking to see whether that's outside of the valid region. (tprev is always going to be inside the valid region, after all.)
The small-step condition is to check that we're actually close to the end of the valid region. E.g. if we set dt0=1e10 then our very first step would probably be outside the valid region, and we'd immediately terminate without doing anything at all.
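The accept/reject mechanics described above can be mimicked with a small pure-Python toy. This is a hypothetical sketch: Heun's method and a crude "halve on NaN, grow by 1.5x" controller stand in for Dopri5 and PIDController. The event mirrors event_fn: it evaluates the vector field at tnext and only fires once the attempted step is smaller than a tolerance, so the integration stops just before t=2.5 instead of stalling.

```python
import math


def vector_field(t, y):
    # Toy field from the first MWE: valid for t < 2.5, NaN afterwards.
    return -y if t < 2.5 else math.nan


def event_fn(tprev, tnext, y, tol=1e-3):
    # Same logic as the suggested event: evaluate the vector field at tnext
    # (the end of the step about to be attempted), and only fire once that
    # step has become small.
    small_step = abs(tnext - tprev) < tol
    return small_step and math.isnan(vector_field(tnext, y))


def toy_solve(t0, t1, y0, dt0, max_steps=10_000):
    # Hypothetical Heun's-method loop with a crude controller, NOT Diffrax's
    # PIDController: a NaN result rejects the step (tprev stays put, tnext
    # moves closer); otherwise the step is accepted and the next grows 1.5x.
    tprev, y, dt = t0, y0, dt0
    for step in range(max_steps):
        tnext = min(tprev + dt, t1)
        if event_fn(tprev, tnext, y):
            return "event", tprev, step
        h = tnext - tprev
        k1 = vector_field(tprev, y)
        k2 = vector_field(tnext, y + h * k1)  # this stage sees tnext
        y1 = y + 0.5 * h * (k1 + k2)
        if math.isnan(y1):
            dt = 0.5 * h  # rejected
        else:
            tprev, y, dt = tnext, y1, 1.5 * h  # accepted
            if tprev >= t1:
                return "done", tprev, step
    return "max_steps", tprev, max_steps


status, t_end, n_steps = toy_solve(0.0, 3.0, 1.0, 0.1)
# With a huge first step, the small-step guard keeps the event from firing at once.
huge_first_step_fires = event_fn(0.0, 1e10, 1.0)
```

Here the toy stops with tprev within the tolerance of t=2.5, and huge_first_step_fires is False, which is exactly the case the small-step guard exists for.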
Ah, that was a silly mistake on my part, thanks for catching it! However, I still see the behavior I described in my actual project even after correcting it (t does not matter for my real problem; the NaNs depend purely on the state y).
Thanks for the help, I will see if I can put together a proper MWE for my problem.
I think I may know why I am seeing this behavior. When I give no dtmin to the step size controller, the integration stalls and eventually hits max_steps, exiting with an error. I think the reason is that, without a minimum step size, the integrator can keep taking tinier and tinier steps as it approaches the region where the NaNs occur.
I am unsure why this happens only when the condition is on state.y and not when it is on t, but here is a (hopefully bug-free) MWE that shows the behavior I am talking about:
@patrick-kidger Is this expected behavior? I naively expected the same behavior whether my condition was on t or on y, but it seems that when the condition is on y, dtmin must not be None in order to force a step into the region where y would become NaN. Could this be because intermediate stages of the RK45 (Dopri5) step see the NaN behavior of the RHS, so the step size in t is reduced to avoid it, and without a dtmin it can keep decreasing ad infinitum?
MWE:
```python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, DiscreteTerminatingEvent
import jax
import jax.numpy as jnp
from jax.lax import cond

def vector_field(t, y, args):
    true_fun = lambda t, x, args: -x
    false_fun = lambda t, x, args: jnp.nan * jnp.ones_like(x)
    return cond(y > 0.1, true_fun, false_fun, t, y, args)

term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4, dtmin=None)

def default_terminating_event_fxn(state, **kwargs):
    terms = kwargs.get("terms", lambda a, x, b: x)
    return jnp.any(jnp.isnan(terms.vf(state.tnext, state.y, 0)))

terminating_event = DiscreteTerminatingEvent(default_terminating_event_fxn)
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller, discrete_terminating_event=terminating_event, max_steps=10000,)
# hits max steps and fails with XLARuntimeError, when instead expect it to hit discrete terminating event and exit gracefully
```
Fix: add a tiny dtmin so that the solver is forced to step into the region where the RHS becomes NaN. This allows the discrete terminating event to trigger, and the integration exits as expected:
```python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, DiscreteTerminatingEvent
import jax
import jax.numpy as jnp
from jax.lax import cond

def vector_field(t, y, args):
    true_fun = lambda t, x, args: -x
    false_fun = lambda t, x, args: jnp.nan * jnp.ones_like(x)
    return cond(y > 0.1, true_fun, false_fun, t, y, args)

term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4, dtmin=1e-8)

def default_terminating_event_fxn(state, **kwargs):
    terms = kwargs.get("terms", lambda a, x, b: x)
    return jnp.any(jnp.isnan(terms.vf(state.tnext, state.y, 0)))

terminating_event = DiscreteTerminatingEvent(default_terminating_event_fxn)
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller, discrete_terminating_event=terminating_event, max_steps=10000,)
# with dtmin=1e-8, eventually takes that final step into the region where NaN is given, and the integration exits
# with the discrete terminating event triggered as expected
print(sol.ts)  # DeviceArray([0. , 1. , 2. ,inf ])
print(sol.ys)  # DeviceArray([1. , 0.368, 0.135, inf])
```
For comparison, here is the case where the condition is on t rather than on state.y. The integrator triggers the discrete terminating event and exits successfully, even though the step size controller has dtmin=None (i.e. the expected behavior):
```python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController, DiscreteTerminatingEvent
import jax
import jax.numpy as jnp
from jax.lax import cond

def vector_field(t, y, args):
    true_fun = lambda t, x, args: -x
    false_fun = lambda t, x, args: jnp.nan * jnp.ones_like(x)
    return cond(t < 2.5, true_fun, false_fun, t, y, args)

term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4, dtmin=None)

def default_terminating_event_fxn(state, **kwargs):
    terms = kwargs.get("terms", lambda a, x, b: x)
    return jnp.any(jnp.isnan(terms.vf(state.tnext, state.y, 0)))

terminating_event = DiscreteTerminatingEvent(default_terminating_event_fxn)
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller, discrete_terminating_event=terminating_event, max_steps=10000,)
# exits successfully: the discrete terminating event triggers once state.tnext passes t=2.5
print(sol.ts)  # DeviceArray([0. , 1. , 2. ,inf ])
print(sol.ys)  # DeviceArray([1. , 0.368, 0.135, inf])
```
I still have trouble with my actual use case, but that is due to very long compile times in jit_diffeqsolve. I can open a separate issue for that once I have an MWE for that specific problem.
> Could this be due to intermediate steps in the RK45 possibly seeing the NaN behavior of the RHS, and reducing the stepsize in t to avoid it, but then without a dtmin it can decrease the stepsize ad infinitum?
Yes, I think this is it exactly. Right now Diffrax treats any NaN from the vector field as an indication that the solve is leaving the valid region, so it reduces the step size:
https://github.com/patrick-kidger/diffrax/blob/5f1978de2fefd8eea16fdb25a1837a4ce1b61ea4/diffrax/integrate.py#L229-L232
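To see why only the y-based condition needs dtmin, here is a deliberately simplified toy solver (hypothetical, not Diffrax's implementation). It assumes, as described above, that any NaN produced while attempting a step gets that step rejected; like Dopri5's last (FSAL) stage, the toy re-evaluates the vector field at the candidate y1. Under that assumption every accepted state has a finite vector field, so an event that checks vf(tnext, state.y) can never fire, whereas a forced step of size dtmin lets a NaN-producing state through. In the t-based MWE the event instead evaluates at state.tnext, which can already lie past t=2.5 after an accepted step, so it fires without dtmin.

```python
import math


def vector_field(t, y):
    # Toy field from the second MWE: valid while y > 0.1, NaN otherwise.
    return -y if y > 0.1 else math.nan


def event_fn(tnext, y):
    # The event from the MWE: fire if the vector field is NaN at the current state.
    return math.isnan(vector_field(tnext, y))


def toy_solve(t0, t1, y0, dt0, dtmin=0.0, max_steps=10_000):
    # Hypothetical explicit-Euler loop, NOT Diffrax's PIDController. Like
    # Dopri5's last (FSAL) stage, it re-evaluates the vector field at the
    # candidate y1; a NaN there rejects the step unless dt has hit dtmin.
    tprev, y, dt = t0, y0, dt0
    for step in range(max_steps):
        tnext = min(tprev + dt, t1)
        if event_fn(tnext, y):
            return "event", tprev, step
        y1 = y + (tnext - tprev) * vector_field(tprev, y)
        if math.isnan(vector_field(tnext, y1)) and dt > dtmin:
            dt = max(0.5 * dt, dtmin)  # reject: tprev stays, step shrinks
        else:
            tprev, y, dt = tnext, y1, 1.5 * dt  # accept (forced if dt == dtmin)
            if tprev >= t1:
                return "done", tprev, step
    return "max_steps", tprev, max_steps


no_dtmin = toy_solve(0.0, 3.0, 1.0, 0.1)
with_dtmin = toy_solve(0.0, 3.0, 1.0, 0.1, dtmin=1e-3)
print(no_dtmin[0])    # -> max_steps
print(with_dtmin[0])  # -> event
```

The toy uses dtmin=1e-3 rather than the 1e-8 from the MWE only to keep the number of steps small; the mechanism is the same.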
Perhaps the simplest thing to do is to add another keyword argument, y1_candidate, to discrete_terminating_event, which holds the output of the solver for the attempted step, before the step size controller kicked in. You can then check that for NaNs and halt integration as soon as you detect any.
I'd be happy to take a PR on this.
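A rough sketch of the y1_candidate idea, again with a hypothetical toy solver rather than Diffrax code; y1_candidate is only the keyword proposed above, not an existing argument. The event sees the raw output of every attempted step and halts on the first NaN candidate, provided the attempted step is already small (the same small-step guard suggested earlier in the thread), so no dtmin is needed.

```python
import math


def vector_field(t, y):
    # Toy field from the second MWE: valid while y > 0.1, NaN otherwise.
    return -y if y > 0.1 else math.nan


def heun_step(t0, t1, y0):
    # One step of Heun's method; a NaN in either stage makes the result NaN.
    h = t1 - t0
    k1 = vector_field(t0, y0)
    k2 = vector_field(t1, y0 + h * k1)
    return y0 + 0.5 * h * (k1 + k2)


def event_fn(y1_candidate, step_size, tol=1e-3):
    # Hypothetical event that receives the candidate output of the attempted
    # step (the proposed y1_candidate), before any rejection happens.
    return math.isnan(y1_candidate) and step_size < tol


def toy_solve(t0, t1, y0, dt0, max_steps=10_000):
    # Toy controller (not Diffrax's): NaN => reject and halve, else grow 1.5x.
    # Note that there is no dtmin.
    tprev, y, dt = t0, y0, dt0
    for step in range(max_steps):
        tnext = min(tprev + dt, t1)
        h = tnext - tprev
        y1 = heun_step(tprev, tnext, y)
        if event_fn(y1, h):
            return "event", tprev, y, step  # halt instead of shrinking forever
        if math.isnan(y1):
            dt = 0.5 * h  # rejected
        else:
            tprev, y, dt = tnext, y1, 1.5 * h  # accepted
            if tprev >= t1:
                return "done", tprev, y, step
    return "max_steps", tprev, y, max_steps


status, t_end, y_end, n_steps = toy_solve(0.0, 3.0, 1.0, 0.1)
```

With this event the toy stops with y just above the 0.1 boundary instead of hitting max_steps.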
As for very long compile times: this typically occurs when you evaluate some large function multiple times (see https://docs.kidger.site/equinox/faq/#my-model-is-slow-to-compile). If so, then you need to minimise the number of call sites.
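A rough illustration of why call sites matter, assuming the repeated calls can be batched together (big_fn, two_call_sites and one_call_site are hypothetical names). Each call site is traced separately, so the program XLA has to compile grows with the number of calls; collapsing them into one call, here via jax.vmap, keeps a single copy. The number of equations in the jaxpr is used only as a rough proxy for program size.

```python
import jax
import jax.numpy as jnp


def big_fn(x):
    # Stand-in for an expensive function (e.g. a spline evaluation).
    # The Python loop is unrolled at trace time, so the traced program is large.
    for _ in range(50):
        x = jnp.sin(x) + 0.1 * x
    return x


def two_call_sites(a, b):
    # big_fn is traced twice, so its operations appear twice in the program.
    return big_fn(a) + big_fn(b)


def one_call_site(a, b):
    # Batch the inputs so big_fn is traced only once.
    out = jax.vmap(big_fn)(jnp.stack([a, b]))
    return out[0] + out[1]


a, b = jnp.ones(3), jnp.linspace(0.0, 1.0, 3)
n_two = len(jax.make_jaxpr(two_call_sites)(a, b).jaxpr.eqns)
n_one = len(jax.make_jaxpr(one_call_site)(a, b).jaxpr.eqns)
print(n_two, n_one)  # the single-call-site version has far fewer equations
```

Other restructurings, such as looping with jax.lax.scan instead of unrolled Python calls, follow the same principle of tracing the large function once.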