How to properly scale with multiple event conditions?

Open LuggiStruggi opened this issue 7 months ago • 7 comments

Hi diffrax team,

I am working on a problem where the number of condition functions (cond_fn) scales with my state dimension. Unfortunately, the runtime currently scales pretty badly with that number, even when jitting the code. Do you have any idea what could be done to improve this?

Here is an example:

\dot{y} + y = 0
y(0) \sim [0.9, 1.0]

event condition (any state):

y - 0.1 = 0
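
For reference, this test problem has a closed-form solution, which gives the time at which the first event should fire. The following is a minimal sketch, assuming the 0.1 threshold stated above; the names `threshold`, `y0` and `expected_event_time` are illustrative and not part of diffrax.

```python
import jax.numpy as jnp

# y' = -y  =>  y_i(t) = y0_i * exp(-t), so component i reaches the
# threshold at t_i = log(y0_i / threshold).
threshold = 0.1  # assumed from the event condition y - 0.1 = 0


def expected_event_time(y0, threshold):
    # The component with the smallest initial value reaches the threshold first.
    return jnp.log(jnp.min(y0) / threshold)


y0 = jnp.array([0.9, 0.95, 1.0])  # illustrative initial values
t_event = expected_event_time(y0, threshold)  # analytically log(9)
```

With a fixed-step Euler solve the numerical event time will differ slightly from this value, so it is a sanity check rather than an exact target.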

Solving with Euler and a fixed step size of 0.005 ms between 0 and 30 ms (or until the first event), the runtime scales really badly with the state dimension:

   Dim |   Time (s)
--------------------
     1 |   0.001022
     2 |   0.000625
     4 |   0.000657
     8 |   0.001162
    16 |   0.001870
    32 |   0.004110
    64 |   0.010566
   128 |   0.044464
   200 |   0.218422

For comparison, here are the timings where I applied the cond_fn only to the first element of the state, y[0]:

   Dim |   Time (s)
--------------------
     1 |   0.000723
     2 |   0.000612
     4 |   0.000595
     8 |   0.000615
    16 |   0.000596
    32 |   0.000716
    64 |   0.000903
   128 |   0.000724
   200 |   0.001835

Do you see any way to speed this up by a reasonable amount? Do you think it would somehow be possible to additionally support vectorized condition functions for events? Like:

def cond_fn(t, y, args, **kwargs):
    return y - 0.1

Thanks for the help!

Here is the MWE I used (the single cond_fn is commented out):

import jax
import jax.numpy as jnp
import diffrax as dfx
import optimistix as optx
from timeit import default_timer as timer

def solve_once(y0, t0, t1, dt0, solver, stepsize_controller, event):
    term = dfx.ODETerm(lambda t, y, args: -y)
    return dfx.diffeqsolve(
        terms=term,
        solver=solver,
        t0=t0,
        t1=t1,
        dt0=dt0,
        y0=y0,
        args=None,
        stepsize_controller=stepsize_controller,
        event=event,
        saveat=dfx.SaveAt(t0=False, t1=True, steps=False),
        max_steps=3000000,
        throw=True,
    )

solve_once = jax.jit(solve_once, static_argnames=["solver", "stepsize_controller", "event"])

def benchmark_diffeqsolve_with_event():
    key = jax.random.PRNGKey(0)
    t0, t1, dt0 = 0.0, 30.0, 0.005
    solver = dfx.Euler()
    controller = dfx.ConstantStepSize()
    dims = [1, 2, 4, 8, 16, 32, 64, 128, 200]

    print(f"{'Dim':>6} | {'Time (s)':>10}")
    print("-" * 20)

    for n in dims:
        key, subkey = jax.random.split(key)
        y0 = jax.random.uniform(subkey, shape=(n,), minval=0.9, maxval=1.1)

        cond_fns = [lambda t, y, *args, i=i, **kwargs: y[i] - 0.01 for i in range(n)]
        # cond_fns = lambda t, y, *args, **kwargs: y[0] - 0.01

        event = dfx.Event(cond_fns, root_finder=optx.Newton(rtol=1e-4, atol=1e-4))

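        # Warm-up call triggers compilation, so it is excluded from the timing.
        jax.block_until_ready(solve_once(y0, t0, t1, dt0, solver, controller, event))

        start = timer()
        # JAX dispatches asynchronously, so wait for the result before stopping the timer.
        jax.block_until_ready(solve_once(y0, t0, t1, dt0, solver, controller, event))
        end = timer()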

        print(f"{n:6d} | {end - start:10.6f}")

benchmark_diffeqsolve_with_event()

LuggiStruggi avatar May 22 '25 12:05 LuggiStruggi

I think scaling with the number of condition functions is probably inevitable unfortunately.

But a single vectorized condition function should already be totally doable as a user! Untested but something like this should probably work for you:

def cond_fn(t, y, args, **kwargs):
    return jnp.min((y - 0.1)**2)

In this case you just need a function that moves smoothly to 0 when any of your components are near 0.1. (EDIT: or jnp.prod(y - 0.1) might also work?)

patrick-kidger avatar May 22 '25 18:05 patrick-kidger

Hi Patrick,

def cond_fn(t, y, args, **kwargs):
    return jnp.prod(y - 0.1)

already seems to speed things up a lot. One big issue is that this way I won't be able to tell which element of y caused the event. Another problem is that as the dimension N of y grows and multiple ys are below 1, I get an event at the wrong time, because for a sufficiently big N and, let's say, all ys being 0.15:

(0.15 - 0.1)^N \approx  0
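
A small sketch of why this produces an early event in single precision, using plain jax.numpy rather than diffrax. It assumes float32 and a sign-comparison check like the one quoted from _integrate.py in this comment; the state values are illustrative. The product underflows to exactly zero, whose sign is 0, so it differs from the positive sign at the previous step even though no component has reached 0.1.

```python
import jax.numpy as jnp


def prod_cond(y):
    return jnp.prod(y - 0.1)


n = 200
y_old = jnp.full((n,), 0.9, dtype=jnp.float32)   # early in the solve
y_new = jnp.full((n,), 0.15, dtype=jnp.float32)  # every component still above 0.1

old_value = prod_cond(y_old)  # 0.8**200 is tiny but still representable in float32
new_value = prod_cond(y_new)  # 0.05**200 underflows to exactly 0.0

# Sign comparison of the kind used for event detection: sign(+x) = 1, sign(0) = 0.
spurious_event = jnp.sign(old_value) != jnp.sign(new_value)
```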

Also,

def cond_fn(t, y, args, **kwargs):
    return jnp.min((y - 0.1)**2)

never works, because it never crosses 0, so we never get the sign change that I think is checked in _integrate.py, line 561:

event_mask_i = jnp.sign(old_event_value_i) != jnp.sign(new_event_value_i)

Anyways

def cond_fn(t, y, args, **kwargs):
    return jnp.min(y - 0.1)

seems to work, but that brings me back to the issue that I won't know which element caused the event.
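
The difference between the squared and the plain minimum can be seen by applying the sign comparison quoted above to both across a single step. This is only a sketch: the two states are made-up values in which the first component crosses 0.1 during the step, and the helper names are illustrative.

```python
import jax.numpy as jnp


def sign_changed(old_value, new_value):
    # Same comparison as the line quoted from _integrate.py.
    return jnp.sign(old_value) != jnp.sign(new_value)


def squared_cond(y):
    return jnp.min((y - 0.1) ** 2)  # never negative, so its sign cannot flip


def linear_cond(y):
    return jnp.min(y - 0.1)  # negative as soon as any component is below 0.1


y_old = jnp.array([0.12, 0.30])  # before the step: all components above 0.1
y_new = jnp.array([0.09, 0.28])  # after the step: first component below 0.1

squared_fires = sign_changed(squared_cond(y_old), squared_cond(y_new))
linear_fires = sign_changed(linear_cond(y_old), linear_cond(y_new))
```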

I could do something like

event_idx = jnp.argmin(sol.ys - 0.1)

afterwards, but that seems a bit redundant. It would be nice if the event mask we get back could also contain vector elements whenever the respective cond_fn is vector-valued. For example, given a scalar-valued and a vector-valued cond_fn passed as cond_fn=[scalar_fn, vector_fn], it would return an event_mask like:

[ Array(False, dtype=bool), Array([False, True, False, False], dtype=bool) ]

Or do you think that is unnecessary? For some problems it would certainly be more elegant to define than a list comprehension of lambdas, or something similar, over all y dimensions.

Do you think this extension is very difficult to implement? If you think it's reasonably doable, I could try to have a look at it. Thanks for your help! :)

LuggiStruggi avatar May 23 '25 08:05 LuggiStruggi

[Moved what was here to new issue]

LuggiStruggi avatar May 23 '25 09:05 LuggiStruggi

I could do something like event_idx = jnp.argmin(sol.ys - 0.1) afterwards, but that seems a bit redundant.

I think that's probably the best way to handle it! I think returning more than 'which event triggered' is probably beyond the scope of the event API.
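
A sketch of that post hoc lookup, assuming the solve saved only the final state, so that sol.ys has shape (1, n). The array `ys` below is a hypothetical stand-in for sol.ys, and `tol` is an illustrative tolerance that is not taken from the thread.

```python
import jax.numpy as jnp

threshold = 0.1
tol = 1e-3  # illustrative; in practice it should reflect the root finder's tolerances

# Hypothetical stand-in for sol.ys at the event time, shape (1, n).
ys = jnp.array([[0.35, 0.1000004, 0.22, 0.41]])
y_event = ys[-1]

# Index of the component with the smallest value relative to the threshold.
event_idx = jnp.argmin(y_event - threshold)

# Per-component boolean mask, which also reports simultaneous crossings.
component_mask = (y_event - threshold) <= tol
```

The mask has the same shape as the state, so it can stand in for a vector-valued event mask without any change to the event API.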

patrick-kidger avatar May 23 '25 20:05 patrick-kidger

Okay! The only worry I have with this approach is somewhat similar to https://github.com/patrick-kidger/diffrax/issues/639. In the case of multiple roots, how can I be certain the root finder will pick up the very first one? Any idea how to handle that? :) Thanks for the help!

LuggiStruggi avatar May 26 '25 10:05 LuggiStruggi

I think you can probably use a combining trick similar to the one I describe in https://github.com/patrick-kidger/diffrax/issues/639#issuecomment-2923070660 . WDYT?

patrick-kidger avatar May 30 '25 18:05 patrick-kidger

Thanks! Yes that seems really reasonable.

Edit: I'm struggling to find a function that has both a unique first root and a sign change at the zero crossing.

LuggiStruggi avatar Jun 02 '25 11:06 LuggiStruggi