How to properly scale with multiple event conditions?
Hi diffrax team,
I am working on a problem where the number of cond_fns scales with my state dimension. Unfortunately, this currently scales pretty badly time-wise, even when jitting the code. Do you have any idea what could be done to improve this?
Here as an example:
\dot{y} + y = 0
y(0) \sim U([0.9, 1.1])
event condition (any state):
y - 0.1 = 0
Solving with Euler and a fixed step size of 0.005 ms between 0 and 30 ms (or until the first event), the runtime scales really badly with the state dimension:
Dim | Time (s)
--------------------
1 | 0.001022
2 | 0.000625
4 | 0.000657
8 | 0.001162
16 | 0.001870
32 | 0.004110
64 | 0.010566
128 | 0.044464
200 | 0.218422
For comparison, here are the timings where I just used the cond_fn on the first element of the state, y[0]:
Dim | Time (s)
--------------------
1 | 0.000723
2 | 0.000612
4 | 0.000595
8 | 0.000615
16 | 0.000596
32 | 0.000716
64 | 0.000903
128 | 0.000724
200 | 0.001835
Do you see any way to speed this up by a reasonable amount? Do you think it would somehow be possible to additionally support vectorized condition functions for events? Like:
def cond_fn(t, y, args, **kwargs):
    return y - 0.1
Thanks for the help!
Here is the MWE I used (single cond_fn commented out):
import jax
import jax.numpy as jnp
import diffrax as dfx
import optimistix as optx
from timeit import default_timer as timer


def solve_once(y0, t0, t1, dt0, solver, stepsize_controller, event):
    term = dfx.ODETerm(lambda t, y, args: -y)
    return dfx.diffeqsolve(
        terms=term,
        solver=solver,
        t0=t0,
        t1=t1,
        dt0=dt0,
        y0=y0,
        args=None,
        stepsize_controller=stepsize_controller,
        event=event,
        saveat=dfx.SaveAt(t0=False, t1=True, steps=False),
        max_steps=3000000,
        throw=True,
    )


solve_once = jax.jit(solve_once, static_argnames=["solver", "stepsize_controller", "event"])


def benchmark_diffeqsolve_with_event():
    key = jax.random.PRNGKey(0)
    t0, t1, dt0 = 0.0, 30.0, 0.005
    solver = dfx.Euler()
    controller = dfx.ConstantStepSize()
    dims = [1, 2, 4, 8, 16, 32, 64, 128, 200]
    print(f"{'Dim':>6} | {'Time (s)':>10}")
    print("-" * 20)
    for n in dims:
        key, subkey = jax.random.split(key)
        y0 = jax.random.uniform(subkey, shape=(n,), minval=0.9, maxval=1.1)
        cond_fns = [lambda t, y, *args, i=i, **kwargs: y[i] - 0.01 for i in range(n)]
        # cond_fns = lambda t, y, *args, **kwargs: y[0] - 0.01
        event = dfx.Event(cond_fns, root_finder=optx.Newton(rtol=1e-4, atol=1e-4))
        _ = solve_once(y0, t0, t1, dt0, solver, controller, event)  # warm-up
        start = timer()
        _ = solve_once(y0, t0, t1, dt0, solver, controller, event)
        end = timer()
        print(f"{n:6d} | {end - start:10.6f}")


benchmark_diffeqsolve_with_event()
I think scaling with the number of condition functions is probably inevitable, unfortunately.
But a single vectorized condition function should already be totally doable as a user! Untested but something like this should probably work for you:
def cond_fn(t, y, args, **kwargs):
    return jnp.min((y - 0.1)**2)
In this case you just need a function that moves smoothly to 0 when any of your components are near 0.1. (EDIT: or jnp.prod(y - 0.1) might also work?)
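For intuition on the product version, here's a quick numpy check (untested in diffrax itself; the states are made up) that the product does flip sign when a single component crosses the threshold:

```python
import numpy as np

y_before = np.array([0.9, 0.5, 0.12])  # all components above 0.1
y_after = np.array([0.9, 0.5, 0.08])   # one component has dipped below 0.1

# With all factors positive the product is positive; once exactly one
# factor goes negative, the product goes negative -> sign change.
print(np.prod(y_before - 0.1))  # positive
print(np.prod(y_after - 0.1))   # negative
```

(Note it flips back if an even number of components end up below the threshold at once.)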
Hi Patrick,
def cond_fn(t, y, args, **kwargs):
    return jnp.prod(y - 0.1)

already seems to speed things up a lot. One big issue is that this way I won't be able to tell which element of y caused the event. Another problem is that as the dimension N of y grows with multiple ys below 1, I get an event at the wrong time, since for sufficiently big N with, say, all ys equal to 0.15,
(0.15 - 0.1)^N \approx 0
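A quick numeric check of that underflow (numpy sketch; JAX's default float32 makes it bite even earlier):

```python
import numpy as np

# (0.15 - 0.1)^N = 0.05^N shrinks geometrically with the dimension N.
# At N = 200 the true value is ~6e-261, which underflows to exactly 0.0
# in float32 (JAX's default dtype), so jnp.prod(y - 0.1) sits at a
# spurious "root" even though no component is anywhere near 0.1.
vals = np.full(200, 0.05, dtype=np.float32)
print(np.prod(vals))  # 0.0
```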
Also,

def cond_fn(t, y, args, **kwargs):
    return jnp.min((y - 0.1)**2)

never works, because it never crosses 0, so we never get a sign change, which I think is checked in _integrate.py line 561:
event_mask_i = jnp.sign(old_event_value_i) != jnp.sign(new_event_value_i)
Anyway,

def cond_fn(t, y, args, **kwargs):
    return jnp.min(y - 0.1)

seems to work, but then I'm back to the issue that I won't know which element caused the spike.
I could do something like

event_idx = jnp.argmin(sol.ys - 0.1)

afterwards, but that feels a bit duplicated. It would be nice if the event mask we get back could also have elements that are vectors if the respective cond_fn was vector-valued as well.
So for example, given a scalar-valued cond_fn and a vector-valued cond_fn as cond_fn=[scalar_fn, vector_fn], it would return an event_mask like:
[ Array(False, dtype=bool), Array([False, True, False, False], dtype=bool) ]
Or do you think that is unnecessary? It would certainly be more elegant to define for some problems than doing a list comprehension of lambdas or something similar over all y dimensions.
Do you think this extension would be very difficult to implement? If you think it's reasonably doable, I could try to have a look at it. Thanks for your help! :)
[Moved what was here to new issue]
I could do something like event_idx = jnp.argmin(sol.ys - 0.1) afterwards, but that feels a bit duplicated.
I think that's probably the best way to handle it! I think returning more than 'which event triggered' is probably beyond the scope of the event API.
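E.g. (a numpy sketch; the final state here is made up and stands in for sol.ys at the event time):

```python
import numpy as np

# Hypothetical final state when the event fires: component 1 has just
# reached the 0.1 threshold, so y - 0.1 is smallest (about zero or
# slightly negative) exactly at the triggering component.
y_final = np.array([0.83, 0.099, 0.47, 0.91])
event_idx = np.argmin(y_final - 0.1)
print(event_idx)  # 1
```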
Okay! The only worry I have with this approach is somewhat similar to https://github.com/patrick-kidger/diffrax/issues/639. In case of multiple roots, how can I be certain the root finder will pick up the very first one? Any idea on how to handle that? :) Thanks for the help!
I think you can probably use a similar combining trick like the one I describe in https://github.com/patrick-kidger/diffrax/issues/639#issuecomment-2923070660 . WDYT?
Thanks! Yes, that seems really reasonable.
Edit: I'm struggling to find a function that has both this unique first root and a sign change at the 0 crossing.