diffrax

Handling of multiple Events triggered during one integration step

Open · LuggiStruggi opened this issue 7 months ago · 7 comments

Hi Diffrax team,

I’ve noticed a subtle edge case in the current implementation of Events: if two cond_fns both change sign within the same integrator step, the one that appears first in the cond_fn PyTree is treated as if it definitely occurred first, potentially ignoring the true first event. To handle this correctly, we would need to solve for the precise root time of each zero crossing, use those times to determine which event actually happens first, and return the system state at that moment. In a case like mine, ignoring this can break the system completely. For example, take the bouncing ball example with two balls: if both hit the floor at nearly the same time, one of them can pass straight through the floor. Once a ball is below the floor there is no coming back, because the sign of its cond_fn never changes again, so no further event is triggered.

Here is an MWE:

```python
import diffrax
import optimistix as optx

# dy/dt = 1, so y(t) = t; both condition functions are linear in t.
term = diffrax.ODETerm(lambda t, y, args: 1.0)
solver = diffrax.Euler()

def g1(t, y, args, **kwargs):
    return t - 1.0  # crosses zero at t = 1.0

def g2(t, y, args, **kwargs):
    return t - 0.5  # crosses zero at t = 0.5

event = diffrax.Event(
    cond_fn=(g1, g2),
    root_finder=optx.Newton(rtol=1e-4, atol=1e-4)
)

# dt0=2.0 puts both zero crossings inside the very first step.
sol1 = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=10.0,
    dt0=2.0,
    y0=0.0,
    event=event,
)

print("ts: ", sol1.ts[-1])
print("event_mask: ", sol1.event_mask)
print("\n")

# Same problem with the order of the condition functions swapped.
event_swapped = diffrax.Event(
    cond_fn=(g2, g1),
    root_finder=optx.Newton(rtol=1e-4, atol=1e-4)
)

sol2 = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=10.0,
    dt0=2.0,
    y0=0.0,
    event=event_swapped,
)

print("ts: ", sol2.ts[-1])
print("event_mask: ", sol2.event_mask)
```

This prints:

```
ts:  1.0
event_mask:  (Array(True, dtype=bool), Array(False, dtype=bool))


ts:  0.5
event_mask:  (Array(True, dtype=bool), Array(False, dtype=bool))
```

As you can see, the event at t = 0.5 is skipped in sol1.
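The selection problem can be reproduced without diffrax. The sketch below is plain JAX with hypothetical helper names: it evaluates `g1` and `g2` at both ends of the single Euler step [0, 2] from the MWE, flags every condition function whose sign changed, and then compares taking the first flagged entry in PyTree order (the behaviour described above) with taking the earliest estimated crossing. The crossing times come from linear interpolation between the step endpoints, which is exact here only because both condition functions are linear in `t`.

```python
import jax.numpy as jnp

# Condition functions from the MWE, written as functions of t only.
def g1(t):
    return t - 1.0

def g2(t):
    return t - 0.5

t_prev, t_next = 0.0, 2.0  # the single Euler step taken with dt0=2.0
conds = (g1, g2)

vals_prev = jnp.array([cond(t_prev) for cond in conds])
vals_next = jnp.array([cond(t_next) for cond in conds])

# An event is flagged for this step if its condition function changed sign.
triggered = jnp.sign(vals_prev) != jnp.sign(vals_next)

# "First in PyTree order": index of the first flagged condition function.
first_in_order = int(jnp.argmax(triggered.astype(jnp.int32)))

# "Earliest crossing": estimate each crossing time by linear interpolation
# (exact for linear g) and pick the smallest among the flagged events.
t_cross = t_prev + (t_next - t_prev) * vals_prev / (vals_prev - vals_next)
t_cross = jnp.where(triggered, t_cross, jnp.inf)
earliest = int(jnp.argmin(t_cross))

print(triggered, first_in_order, earliest, t_cross[earliest])
```

Both functions are flagged, but the PyTree-order choice picks `g1` while the earliest crossing belongs to `g2`.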

If you think this is reasonably doable, I could try to look at it.

Thanks for the help!

LuggiStruggi, May 23 '25 10:05

Here is the two-bouncing-balls example showing why this is an issue:

If you add a small offset to the height of one of the balls (small enough that both events happen in the same time step), it either works fine, if the right ball happens to be chosen by accident, or, because of the behaviour described above, one event is skipped and that ball falls through the floor. Try adding epsilon to the first versus the second ball's height. Here it would also be nice if Events could be triggered in only one direction (https://github.com/patrick-kidger/diffrax/issues/640), since I cannot reset the state to exactly 0; otherwise the event is always triggered. I feel this could also be partly fixed with "vectorized" cond_fns as described in https://github.com/patrick-kidger/diffrax/issues/637, but I'm not sure the fix there would suffice, since I don't think I can assume that the Newton solver will always find the first root.

Here is the quickly assembled MWE:

```python
import jax.numpy as jnp
import optimistix as optx
import diffrax
import matplotlib.pyplot as plt

g = 9.81       # gravitational acceleration
rho = 0.9      # damping factor applied to the velocity at each bounce
T_final = 10.0

# State y has shape (2, 2): one row per ball, columns (height x, velocity v).
def vector_field(t, y, args):
    x, v = y[:, 0], y[:, 1]
    a = -g * jnp.ones_like(v)
    return jnp.stack([v, a], axis=1)

# One condition function per ball: the height of that ball.
def hit_ground1(t, y, args, **kwargs):
    x = y[0, 0]
    return x

def hit_ground2(t, y, args, **kwargs):
    x = y[1, 0]
    return x

root_finder = optx.Newton(1e-3, 1e-3, optx.rms_norm)
event = diffrax.Event((hit_ground1, hit_ground2), root_finder)

solver = diffrax.Euler()

ts_segments = []
ys_segments = []

t0 = 0.0
epsilon = 0.00001  # small offset so both balls land within the same step
y0 = jnp.array([[10, 0.0], [10 + epsilon, 0.0]])
dt0 = 0.001
t1 = T_final

# Integrate until an event fires, apply the bounce, then restart from there.
while t0 < T_final:
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        solver,
        t0,
        t1,
        dt0,
        y0,
        event=event,
        max_steps=2000000,
        saveat=diffrax.SaveAt(t1=True, ts=jnp.linspace(t0, t1, 100)),
    )
    # Drop save slots that were never reached (filled with inf).
    mask = ~jnp.isinf(sol.ts)
    ts = sol.ts[mask]
    ys = sol.ys[mask]

    ts_segments.append(ts)
    ys_segments.append(ys)

    t_event = float(ts[-1])
    y_event = ys[-1]

    if t_event >= T_final:
        break

    x_ev, v_ev = y_event[:, 0], y_event[:, 1]
    t0 = t_event

    event_mask = jnp.stack(sol.event_mask)

    # Bounce the flagged ball(s): lift just above the floor and reverse
    # the velocity with damping.
    x_new = jnp.where(event_mask, 0.0000001, x_ev)
    v_new = jnp.where(event_mask, -rho * v_ev, v_ev)

    y0 = jnp.stack((x_new, v_new), axis=1)
    print(f"{t0}/{T_final}")

t_all = jnp.concatenate(ts_segments)
y_all = jnp.concatenate(ys_segments, axis=0)

plt.plot(t_all, y_all[:, 0, 0])
plt.plot(t_all, y_all[:, 1, 0])
plt.show()
```

LuggiStruggi, May 23 '25 14:05

At least right now this is intended -- mostly because it was simpler to implement 😁

That said, I suspect this wouldn't be too tricky to change. Right now, during the integration, we select the first such event here:

https://github.com/patrick-kidger/diffrax/blob/aebc3dde11e666389d519e0213de897a540f39eb/diffrax/_integrate.py#L585-L590

Then, after the integration halts (because an event function has changed sign), we backtrack with a root find here, ignoring all event functions except the one that triggered:

https://github.com/patrick-kidger/diffrax/blob/aebc3dde11e666389d519e0213de897a540f39eb/diffrax/_integrate.py#L709

In order to change this, I imagine that we could allow multiple events to trigger during the integration, and then adjust the root-finding logic to be robust to this (so that the unique root is the earliest of the zero crossings of the triggered event functions).

I'd be happy to take a PR on this.

patrick-kidger, May 23 '25 20:05

Okay thanks! I guess I will have a look at it then! :)

LuggiStruggi, May 26 '25 10:05

> and then adjust the root-finding logic to be robust to this (so that the unique root is the minimum of all zero-crossings for the triggered event functions)

Any idea how to formulate this? I don't have a good idea how that would be possible just by changing the function passed to the root finder. If you have a good idea, that would help me a lot ^^ (same as the last comment here: https://github.com/patrick-kidger/diffrax/issues/637).

Thanks so much for your help!

As a start, I think it would already help to start the root finder at y0=final_state.event_tprev, since that is closer to the first root than to any later roots. Currently the code uses final_state.event_tnext as the initial time for the root finder:

```python
_event_root_find = optx.root_find(
    _to_root_find,
    event.root_finder,
    y0=final_state.event_tnext,
    options=_options,
    throw=False,
)
```

Hopefully this reduces the issue a bit (I made that change here: https://github.com/patrick-kidger/diffrax/pull/644). Edit: With this change I get the same result as before. Did I change the correct value?
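Whether the starting point matters depends on how many roots lie inside the step. As an illustration only, the sketch below uses a hand-rolled scalar Newton iteration (not `optx.Newton`) on a hypothetical condition function with two roots in the step [0, 2], at t = 0.5 and t = 1.0. Started from the end of the step, analogous to `event_tnext`, it converges to the later root; started from the beginning, analogous to `event_tprev`, it converges to the earlier one in this case. For other functions Newton can still jump past the nearest root, so starting at `event_tprev` makes finding the first root more likely but does not guarantee it.

```python
import jax
import jax.numpy as jnp

# Hypothetical condition function with two zero crossings inside [0, 2].
def cond(t):
    return (t - 0.5) * (t - 1.0)

def newton(f, t0, steps=50):
    """Plain scalar Newton iteration: t <- t - f(t) / f'(t)."""
    df = jax.grad(f)
    t = jnp.asarray(t0, dtype=jnp.float32)
    for _ in range(steps):
        t = t - f(t) / df(t)
    return t

root_from_end = newton(cond, 2.0)    # like starting at event_tnext
root_from_start = newton(cond, 0.0)  # like starting at event_tprev
print(root_from_end, root_from_start)
```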

Here is the vectorized bouncing ball:

```python
import jax.numpy as jnp
import optimistix as optx
import diffrax
import matplotlib.pyplot as plt

g = 9.81
rho = 0.9
T_final = 10.0

def vector_field(t, y, args):
    x, v = y[:, 0], y[:, 1]
    a = -g * jnp.ones_like(v)
    return jnp.stack([v, a], axis=1)

# A single "vectorized" condition function: the height of the lowest ball.
def hit_ground(t, y, args, **kwargs):
    x = y[:, 0]
    return jnp.min(x)

root_finder = optx.Newton(1e-3, 1e-3, optx.rms_norm)
# The commented-out `(False)` argument was presumably meant to select a
# trigger direction (see issue #640).
event = diffrax.Event(hit_ground, root_finder)  # , (False))

solver = diffrax.Euler()

ts_segments = []
ys_segments = []

t0 = 0.0
epsilon = 0.000001
y0 = jnp.array([[10 + epsilon, 0.0], [10, 0.0]])
dt0 = 0.001
t1 = T_final

while t0 < T_final:
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        solver,
        t0,
        t1,
        dt0,
        y0,
        event=event,
        max_steps=2000000,
        saveat=diffrax.SaveAt(t1=True, ts=jnp.linspace(t0, t1, 100)),
    )
    mask = ~jnp.isinf(sol.ts)
    ts = sol.ts[mask]
    ys = sol.ys[mask]

    ts_segments.append(ts)
    ys_segments.append(ys)

    t_event = float(ts[-1])
    y_event = ys[-1]

    if t_event >= T_final:
        break

    x_ev, v_ev = y_event[:, 0], y_event[:, 1]
    t0 = t_event

    if sol.event_mask:
        # Bounce whichever ball is lowest at the event time.
        event_idx = jnp.argmin(x_ev)
        x_ev = x_ev.at[event_idx].set(0.00001)
        v_ev = v_ev.at[event_idx].set(-rho * v_ev[event_idx])

    y0 = jnp.stack((x_ev, v_ev), axis=1)
    print(f"{t0}/{T_final}")

t_all = jnp.concatenate(ts_segments)
y_all = jnp.concatenate(ys_segments, axis=0)

plt.plot(t_all, y_all[:, 0, 0])
plt.plot(t_all, y_all[:, 1, 0])
plt.show()
```

LuggiStruggi, May 27 '25 07:05

Okay, I implemented a first draft that runs one root finder per triggered event (https://github.com/patrick-kidger/diffrax/pull/645). It now works correctly, but it also slows my use case down significantly, even more than before. I also don't know whether the way I did it is the most elegant. Any other ideas?
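For reference, the per-event idea can be sketched in a few lines: one bracketed root find per condition function over the step, events that did not trigger masked out, then the earliest root taken. This is a hypothetical illustration, not the code from the PR. Bisection vectorized with `jax.vmap` stands in for whatever root finder the `Event` is configured with, and the condition functions are the linear ones from the first MWE. The work grows with the number of events, since every event gets its own root find, even those masked out afterwards.

```python
import jax
import jax.numpy as jnp

# The two condition functions from the first MWE, stacked into one vector.
def conds(t):
    t = jnp.asarray(t)
    return jnp.stack([t - 1.0, t - 0.5])

def bisect(f_i, lo, hi, iters=40):
    """Bisection for a scalar function whose sign differs at lo and hi."""
    f_lo = f_i(lo)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        f_mid = f_i(mid)
        same_sign = jnp.sign(f_mid) == jnp.sign(f_lo)
        lo = jnp.where(same_sign, mid, lo)
        f_lo = jnp.where(same_sign, f_mid, f_lo)
        hi = jnp.where(same_sign, hi, mid)
    return 0.5 * (lo + hi)

t_prev, t_next = 0.0, 2.0
triggered = jnp.sign(conds(t_prev)) != jnp.sign(conds(t_next))

# One root find per event, vectorized over the event index i.
def root_for_event(i):
    return bisect(lambda t: conds(t)[i], jnp.float32(t_prev), jnp.float32(t_next))

roots = jax.vmap(root_for_event)(jnp.arange(2))
roots = jnp.where(triggered, roots, jnp.inf)  # ignore events that did not trigger
first_event = int(jnp.argmin(roots))
t_event = roots[first_event]
print(roots, first_event, t_event)
```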

LuggiStruggi, May 28 '25 10:05

So I think this should be possible with a single root find.

The key idea is that over the course of the step we know (a) which events triggered, and (b) whether each triggered event happened due to an upcrossing or a downcrossing.

Assume without loss of generality that all the triggered events happened due to upcrossings. Then, evaluating our events at any given time t, we know that we are not at the first-triggered event if any of our triggered event functions is positive. If one is, then it has already triggered!

So, given triggered event values e_i, something like min(abs(e_i) for i) + sum_i relu(e_i) might be suitable? At least one of the e_i must be zero for the first term to be zero, and all of the e_i must be nonpositive for the second term to be zero.

I'm just putting something together off the top of my head; there might be literature on the best way to handle this detail.
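A minimal sketch of this objective, written as a hypothetical helper in plain JAX (not diffrax code). Following the reduction above, each event function is first multiplied by minus its sign at the start of the step, so every triggered crossing becomes an upcrossing; untriggered events are excluded from both terms. On the two linear condition functions from the first MWE, the objective is zero at the earlier crossing t = 0.5 but not at the later one t = 1.0.

```python
import jax
import jax.numpy as jnp

def conds(t):
    # Condition functions from the first MWE: g1 = t - 1, g2 = t - 0.5.
    t = jnp.asarray(t)
    return jnp.stack([t - 1.0, t - 0.5])

t_prev, t_next = 0.0, 2.0
e_prev, e_next = conds(t_prev), conds(t_next)
triggered = jnp.sign(e_prev) != jnp.sign(e_next)
# Flip signs so every triggered crossing goes from negative to positive.
orientation = -jnp.sign(e_prev)

def first_crossing_objective(t):
    e = orientation * conds(t)
    # min_i |e_i| over triggered events: zero only at some crossing.
    closest = jnp.min(jnp.where(triggered, jnp.abs(e), jnp.inf))
    # sum_i relu(e_i) over triggered events: zero only if none has crossed yet.
    already_crossed = jnp.sum(jnp.where(triggered, jax.nn.relu(e), 0.0))
    return closest + already_crossed

print(first_crossing_objective(0.5), first_crossing_objective(1.0))
```

The objective is non-negative, so it touches zero rather than changing sign there: a bracketing method such as bisection cannot be applied to it directly, and a Newton-type solver has to cope with the kinks introduced by `abs`, `min` and `relu`.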

patrick-kidger, May 30 '25 18:05

Wow, nice, this seems like a good solution to me! Would you then suggest changing the current event logic to turn a list of condition functions into a scalar cond_fn? I guess for upcrossings we could then add sum_i relu(e_i) after the min term, for downcrossings sum_i relu(-e_i), and for bidirectional events sum_i relu(e_i) + sum_i relu(-e_i), which would already account for https://github.com/patrick-kidger/diffrax/issues/640. Would you then use argmin(abs(e)) to get the proper event mask, as discussed in https://github.com/patrick-kidger/diffrax/issues/637? I could try to put something like that together.
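The `argmin(abs(e))` idea for the event mask can be sketched as follows, assuming the single root find has already returned the event time. Everything here is hypothetical: the condition functions are the linear ones from the first MWE, the root time t = 0.5 is hard-coded, and ties between events landing at the same time would need extra handling.

```python
import jax.numpy as jnp

def conds(t):
    # Condition functions from the first MWE: g1 = t - 1, g2 = t - 0.5.
    t = jnp.asarray(t)
    return jnp.stack([t - 1.0, t - 0.5])

triggered = jnp.array([True, True])  # both changed sign over the step [0, 2]
t_root = 0.5                         # assumed output of the single root find

e = conds(t_root)
# Exclude untriggered events, then mark the one closest to zero as the one that fired.
idx = jnp.argmin(jnp.where(triggered, jnp.abs(e), jnp.inf))
event_mask = tuple(bool(i == idx) for i in range(e.shape[0]))
print(event_mask)  # one boolean per cond_fn, mirroring the structure of sol.event_mask
```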

Edit: I didn't think of this at first, but the issue here again is that this function has no sign change and therefore will not trigger an event (at least that's the problem in my vectorized case, but I guess here we could pass it only to the root finder).

Edit 2: Actually, on closer inspection, I noticed that min() should be entirely sufficient for a unique first root plus a sign change, assuming a downcrossing. The same goes for max() with an upcrossing. I guess only the bidirectional case is a problem? The issue in my example problem was actually that my root finder was not accurate enough: it sometimes returned a time at which, if integrated up to, the other ball was already in the negative domain. Therefore there was never a sign change for that event in the next integration, and the ball just fell to -infinity. I don't think we can make this more robust than that.

My only concern with min and all the other "cond_fn wrappers" we talked about is that they might be harder for the root finder to solve, since they have jumps in their gradient.
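The `min()` argument can be checked on a toy version of the two-ball problem in which the trajectories are known in closed form: free fall from rest, x_i(t) = h_i - g t^2 / 2, with no bounce modelled and illustrative starting heights. With all heights positive at the start, `min_i x_i(t)` has a single downcrossing, at the earliest impact time sqrt(2 h_i / g). The sketch locates it by bisection, which only looks at signs, so the kink in `min` does not affect it; a gradient-based solver such as Newton is where the concern about the kink applies.

```python
import jax.numpy as jnp

g = 9.81
h = jnp.array([10.0 + 1e-3, 10.0])  # illustrative starting heights of the two balls

def heights(t):
    # Closed-form free fall from rest; no bounce is modelled here.
    return h - 0.5 * g * t**2

def min_height(t):
    # Scalar "vectorized" condition: crosses zero when the first ball lands.
    return jnp.min(heights(t))

# Bisection on [lo, hi], assuming min_height(lo) > 0 > min_height(hi).
lo, hi = 0.0, 2.0
for _ in range(60):
    mid = 0.5 * (lo + hi)
    if min_height(mid) > 0:
        lo = mid
    else:
        hi = mid
t_hit = 0.5 * (lo + hi)

t_expected = float(jnp.min(jnp.sqrt(2 * h / g)))
print(t_hit, t_expected)
```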

LuggiStruggi, Jun 02 '25 11:06