
Events

Open cholberg opened this issue 1 year ago • 8 comments

Updates to how events are handled in diffrax. The main changes are:

  • There is now only one event class, Event.
  • Multiple cond_fn are supported. An event is triggered whenever one of them changes sign.
  • Differentiable event times are supported when the cond_fn are real-valued and a root_finder is provided (see the usage sketch after this list).
  • New tests for the new event implementation.
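
A minimal usage sketch, pieced together from the tests later in this thread (the cond_fn signature shown here matches the PR at the time of writing and may still change):

import diffrax
import jax.numpy as jnp
import optimistix as optx

term = diffrax.ODETerm(lambda t, y, args: 1.0)
solver = diffrax.Tsit5()

def cond_fn(state, y, **kwargs):
    # Real-valued condition: the event triggers when this changes sign.
    return y

# A real-valued cond_fn plus a root finder gives a (differentiable) event time
# located within the triggering step.
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)

sol = diffrax.diffeqsolve(term, solver, 0.0, jnp.inf, 1.0, -10.0, event=event)
# sol.ts[-1] should be approximately 10.0, where y crosses zero.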

Some things that still might require a little thinking about:

  • Update documentation (and maybe examples, too).
  • What to do with outdated tests.
  • How to handle the case where the control is not smooth, e.g., Brownian motion.

cholberg avatar Mar 15 '24 09:03 cholberg

Thank you for the comments, Patrick. I have gone through and made changes accordingly. The only thing I would disagree with is your comment on not having to evaluate cond_fn. Correct me if I am wrong, but would we not need to know the sign of cond_fn evaluated at the initial condition? That is, a priori you cannot know whether an event happens when the condition function goes from negative to positive or vice versa. Furthermore, if we want cond_fn to also possibly take the state as an argument, event_result would need to be initialized after init_state is defined.

cholberg avatar Mar 30 '24 15:03 cholberg

Just letting you know that I've not forgotten about this! I'm trying to focus on getting #344 in first, and then I'm hoping to return to this. They're both quite large changes so I don't want them to step on each other's toes.

patrick-kidger avatar Apr 10 '24 21:04 patrick-kidger

Okay, #344 is in! I'd love to get this in next.

I appreciate that's rather a lot of merge conflicts. If you're able to rebase onto the latest dev branch then I'd be very happy to come back to this and start work on getting this in.

patrick-kidger avatar May 04 '24 12:05 patrick-kidger

Perfect, I rebased and squashed all the commits into a single big one. Quite a few tests are failing when I run it locally now, but I just wanted to update it so you could have a look.

Two thoughts since we last touched base:

  1. We should probably narrow down exactly what arguments should be passed to the functions in cond_fn.
  2. We might want to make it so that the user has a way of passing arguments to the optimistix.root_find call that finds the exact event time.

cholberg avatar May 06 '24 09:05 cholberg

Hi! Maybe I am using it wrong, but at the moment I can't get the root_finder to do anything. I get the same event times whether I use the Newton method as a root finder or just root_finder = None. Also, the tests still pass when setting the root finder to None:

from typing import cast

import diffrax
import jax
import jax.numpy as jnp
import optimistix as optx
from jax import Array


def test_continuous_event_time():
    term = diffrax.ODETerm(lambda t, y, args: 1.0)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = 1.0
    y0 = -10.0

    def cond_fn(state, y, **kwargs):
        assert isinstance(state.y, jax.Array)
        return y

    # root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
    root_finder = None
    event = diffrax.Event(cond_fn, root_finder)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
    assert jnp.all(jnp.isclose(cast(Array, sol.ts), 10.0, 1e-5))

Similarly, in my own code the resulting event times do not change whether I use a root finder or not:

import jax.numpy as jnp
import optimistix as optx
from diffrax import Event, ODETerm, Tsit5, diffeqsolve


def bouncing_ball():
    g = 9.81
    damping = 0.8
    max_bounces = 10
    vx_0 = 5.0

    def dynamics(t, y, args):
        x, y, vx, vy = y
        dxdt = vx
        dydt = vy
        dvxdt = 0
        dvydt = -g
        return jnp.array([dxdt, dydt, dvxdt, dvydt])

    def cond_fn(state, y, **kwargs):
        return y[1] < 0

    y0 = jnp.array([0.0, 10.0, vx_0, 0.0])
    t0, t1 = 0, float('inf')

    times = []
    states = []

    for _ in range(max_bounces):
        root_finder = optx.Newton(1e-5, 1e-5)
        # root_finder = None
        event = Event(cond_fn, root_finder=root_finder)
        solver = Tsit5()

        sol = diffeqsolve(ODETerm(dynamics), solver, t0, t1, 0.01, y0, event=event)

        t0 = sol.ts[-1]
        last_y = sol.ys[-1]
        y0 = last_y * jnp.array([1, 0, 1, -damping])
        times.append(sol.ts)
        states.append(y0)

    return jnp.array(times), jnp.array(states)

Thanks for the help! (Also, sorry if this is the wrong place to ask; just let me know where I should write this.) :)

LuggiStruggi avatar May 07 '24 08:05 LuggiStruggi

Hi! Maybe I am using it wrong, but at the moment I can't get the root_finder to do anything. I get the same event times whether I use the Newton method as a root finder or just root_finder = None. Also, the tests still pass when setting the root finder to None:

Ah, thanks for mentioning this. This is essentially because the solution to the ODE is linear and because dt0 divides 10.0, which is exactly the time at which the solution crosses 0. In other words, it just so happens that the root falls exactly at the end point of the last step of the solver, which is also the event time returned when no root finder is provided. I have tweaked the test so this is no longer the case.
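
For illustration, here is a small variant (a sketch, assuming constant steps of size dt0 as in the test) where the refinement is visible: with dt0 = 3.0 the crossing at t = 10 no longer coincides with a step endpoint, so without a root finder the reported event time should be the end of the triggering step (t = 12), while with Newton it should be refined towards t = 10.

import diffrax
import jax.numpy as jnp
import optimistix as optx

term = diffrax.ODETerm(lambda t, y, args: 1.0)
solver = diffrax.Tsit5()
t0, t1, dt0, y0 = 0.0, jnp.inf, 3.0, -10.0  # the zero crossing at t = 10 lies between step endpoints

def cond_fn(state, y, **kwargs):
    return y

event_coarse = diffrax.Event(cond_fn)  # no root finder: event time snaps to a step endpoint
event_refined = diffrax.Event(cond_fn, optx.Newton(1e-5, 1e-5, optx.rms_norm))

t_coarse = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event_coarse).ts[-1]
t_refined = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event_refined).ts[-1]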

Similarly, in my own code the resulting event times do not change whether I use a root finder or not:

In your example the cond_fn returns a boolean. In this case the returned event time is exactly the end of the first solver step at which cond_fn has switched sign, i.e., your event time will be n * dt0 where n is the number of steps taken up to and including the point at which y[1] >= 0. Note that this is the same behaviour as when root_finder=None.

If you want continuous event times, you should specify a real-valued condition function. In your bouncing ball example, this would simply correspond to setting:

def cond_fn(state, y, **kwargs):
    return y[1]

cholberg avatar May 07 '24 09:05 cholberg

Ah, I see! It was just a bit confusing that both a boolean and a comparison with 0 are possible. I tried a real-valued cond_fn initially, but then the event was triggered immediately at t=0 every time, because I reset the state to 0 at every jump. I guess setting it to a small value should work. Thank you! :)
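
For reference, one way to do that in the bouncing-ball loop above would be to restart the ball slightly above the ground (eps here is just an arbitrary small offset, not anything the library requires):

# Restart slightly above the ground so the real-valued cond_fn (which returns y[1])
# is strictly positive at the new t0 rather than exactly zero.
eps = 1e-6
y0 = last_y * jnp.array([1.0, 0.0, 1.0, -damping]) + jnp.array([0.0, eps, 0.0, 0.0])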

LuggiStruggi avatar May 07 '24 09:05 LuggiStruggi

Regarding my confusion: maybe this is a bad idea, but would it make sense to emit a warning if both a boolean cond_fn and a root finder are passed to the event, since that combination should never make sense?

LuggiStruggi avatar May 07 '24 15:05 LuggiStruggi

If I pass multiple cond_fn, how would you suggest determining which of them caused the Event? :) Thanks for the help!

LuggiStruggi avatar May 13 '24 07:05 LuggiStruggi

Regarding my confusion: maybe this is a bad idea, but would it make sense to emit a warning if both a boolean cond_fn and a root finder are passed to the event, since that combination should never make sense?

Hmm, I'm not too sure about this. There might be cases where you have one real-valued event function and one boolean one. E.g., in the bouncing ball example we might want to add an extra function to cond_fn stopping the solve when the velocity is low enough (i.e., the ball has reached a steady state and stopped bouncing). But happy to hear your thoughts as well, @patrick-kidger.
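
For concreteness, a sketch of such a mixed event (the condition-function signatures follow the current tests in this PR, and the velocity threshold is an arbitrary choice):

import diffrax
import jax.numpy as jnp
import optimistix as optx

def hit_ground(state, y, **kwargs):
    return y[1]  # real-valued: the root finder can locate the exact crossing

def stopped_bouncing(state, y, **kwargs):
    return jnp.abs(y[3]) < 1e-2  # boolean: triggers at the end of a step, no root finding

event = diffrax.Event(
    cond_fn=[hit_ground, stopped_bouncing],
    root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm),
)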

If I pass multiple cond_fn, how would you suggest determining which of them caused the Event? :) Thanks for the help!

The solution returned by diffeqsolve has an attribute called event_mask which is a PyTree of the same structure as your cond_fn where each leaf is False if the corresponding condition function did not trigger an event and True otherwise. (Note: for now, only one leaf can be True.) If you're interested in a more involved example of how this all works with multiple event handling you might want to check out this repo (specifically snn.py). This is very much a work in progress, but it should serve as an example of how it all works. Hope that helps!
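
A small self-contained sketch of reading event_mask (with two real-valued condition functions, the second of which should trigger first):

import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: 1.0)
solver = diffrax.Tsit5()

cond_a = lambda state, y, **kwargs: y        # crosses zero at t = 10
cond_b = lambda state, y, **kwargs: y + 5.0  # crosses zero at t = 5, so it fires first

event = diffrax.Event(cond_fn=[cond_a, cond_b])
sol = diffrax.diffeqsolve(term, solver, 0.0, jnp.inf, 1.0, -10.0, event=event)

# event_mask mirrors the structure of cond_fn: a list of two booleans here,
# at most one of which is True.
triggered = [name for name, m in zip(["cond_a", "cond_b"], sol.event_mask) if m]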

cholberg avatar May 13 '24 11:05 cholberg

Agreed on the first point. Happy to add a warning if all events are Boolean, though -- no strong feelings.

patrick-kidger avatar May 13 '24 18:05 patrick-kidger

Hi folks! Very excited about this PR, as I'm thinking about quantum jump applications for dynamiqs. Unfortunately, I'm running into an error if I try to pass an option different from saveat = SaveAt(t1=True):

import diffrax as dx
import optimistix as optx
import jax
import jax.numpy as jnp

term = dx.ODETerm(lambda t, y, args: y)
solver = dx.Tsit5()
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = 1.0
ts = jnp.arange(t0, t1, dt0)

def cond_fn(state, y, **kwargs):
    assert isinstance(state.y, jax.Array)
    return y - jnp.exp(1.0)

fn = lambda t, y, args: y

subsaveat_a = dx.SubSaveAt(ts=ts, fn=fn)  # save solution regularly
subsaveat_b = dx.SubSaveAt(t1=True)  # save last state
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
    term, solver, t0, t1, dt0, y0, saveat=saveat, event=event
)

This runs into an error on line 1204 of _integrate.py

    ys = jtu.tree_map(lambda _y, _yevent: _y.at[-1].set(_yevent), ys, yevent)
ValueError: Expected list, got Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>.

Thanks!

dkweiss31 avatar May 14 '24 18:05 dkweiss31

Ah, yes, I forgot to handle the case where multiple SubSaveAts are passed. It should be fixed now with the latest commit. Let me know if that works for you.

cholberg avatar May 15 '24 10:05 cholberg

Thanks for the quick response! That fixed the example I posted; however, I am still running into issues with slightly more complicated examples that are closer to how dynamiqs actually calls diffeqsolve, and more specifically to how it saves data as the simulation progresses. Here is a MWE where y saves the state and y2 saves "expectation values".

import diffrax as dx
import optimistix as optx
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import Array

term = dx.ODETerm(lambda t, y, args: y + t)
solver = dx.Tsit5()
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[1.0], [0.0]])
ts = jnp.arange(t0, t1, dt0)


def cond_fn(state, y, **kwargs):
    assert isinstance(state.y, jax.Array)
    norm = jnp.einsum("ij,ij->", y, y)
    return norm - jnp.exp(1.0)


class Saved(eqx.Module):
    y: Array
    y2: Array

def save_fn(t, y, args):
    ynorm = jnp.einsum("ij,ij->", y, y)
    return Saved(y, jnp.array([ynorm, 3 * ynorm]))

subsaveat_a = dx.SubSaveAt(ts=ts, fn=save_fn)  # save solution regularly
subsaveat_b = dx.SubSaveAt(t1=True)  # save last state
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
    term, solver, t0, t1, dt0, y0, saveat=saveat, event=event
)

This runs into

ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 2) shape=(2,)

Interestingly, the code runs without errors if I replace return Saved(y, jnp.array([ynorm, 3 * ynorm])) with return Saved(y, jnp.array([y, 3 * y]))

dkweiss31 avatar May 15 '24 13:05 dkweiss31

You're right, I did not account for the fact that SubSaveAt.fn could return a PyTree. It should be fixed now; at least your MWE works with the latest commit.

cholberg avatar May 15 '24 17:05 cholberg

Indeed that fixed my MWE! I hate to be such a pain, but I am now running into another issue; here is an example that is much closer to the actual code I am interested in running.

import diffrax as dx
import optimistix as optx
import jax.numpy as jnp

L_op = 0.1 * jnp.array([[0.0, 1.0],
                        [0.0, 0.0]], dtype=complex)
H = 0.0 * L_op
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[0.0], [1.0]], dtype=complex)


def vector_field(t, state, _args):
    L_d_L = jnp.transpose(L_op) @ L_op
    new_state = -1j * (H - 1j * 0.5 * L_d_L) @ state
    return new_state


def cond_fn(state, **kwargs):
    psi = state.y
    prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
    return prob - 0.95


term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)

sol = dx.diffeqsolve(
    term, dx.Tsit5(), t0, t1, dt0, y0, event=event
)

This runs into

equinox._errors.EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

It's possible this could be due to my use of complex numbers, which as I understand it are only partly supported in diffrax? However, with the previous DiscreteTerminatingEvent I did not run into such errors. Note that if I change 0.95 to 0.0 I don't see this error (likely because the event is not triggered).

dkweiss31 avatar May 15 '24 23:05 dkweiss31

Complex number support is definitely still iffy. Can you try reproducing this without using them? You can still solve the same equation, mathematically speaking, just by splitting things into separate real and imaginary components.

patrick-kidger avatar May 19 '24 01:05 patrick-kidger

Right, here is the same example using the complex->real isomorphism described e.g. here (see Eq. (9)). I am getting the same error as before, so it seems this is not a complex-number issue:

import diffrax as dx
import optimistix as optx
import jax.numpy as jnp


def mat_cmp_to_real(matrix):
    re_matrix = jnp.real(matrix)
    im_matrix = jnp.imag(matrix)
    top_row = jnp.hstack((re_matrix, -im_matrix))
    bottom_row = jnp.hstack((im_matrix, re_matrix))
    return jnp.vstack((top_row, bottom_row))


def vec_cmp_to_real(vector):
    re_vec = jnp.real(vector)
    im_vec = jnp.imag(vector)
    return jnp.vstack((re_vec, im_vec))


L_op = 0.1 * jnp.array([[0.0, 1.0],
                        [0.0, 0.0]], dtype=complex)
L_d_L = jnp.transpose(L_op) @ L_op
H = 0.0 * L_op
_prop = -1j * (H - 1j * 0.5 * L_d_L)
_y0 = jnp.array([[0.0], [1.0]], dtype=complex)
prop = mat_cmp_to_real(_prop)
y0 = vec_cmp_to_real(_y0)
t0 = 0
t1 = 100.0
dt0 = 1.0


def vector_field(t, state, _args):
    new_state = prop @ state
    return new_state


def cond_fn(state, **kwargs):
    psi = state.y
    prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
    return prob - 0.95


term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)

sol = dx.diffeqsolve(
    term, dx.Tsit5(), t0, t1, dt0, y0, event=event
)

dkweiss31 avatar May 19 '24 16:05 dkweiss31

Thank you @dkweiss31! @cholberg, are you able to have a look at this example?

Anyway, as promised! Getting this in is my next priority. As such I've gone through and submitted a PR against this branch here. I don't claim that everything I've done is necessarily correct, so @cholberg I'd appreciate a review! :D

patrick-kidger avatar May 19 '24 19:05 patrick-kidger

@dkweiss31, sorry for being a little unresponsive these last couple of days! I will be quite busy for the next 2-3 days, but will try to look at it as soon as possible.

@patrick-kidger, oh, that's great! But yeah, I have a deadline coming up, so I can't promise that I'll have time to look at it before Thursday. I will make it a priority afterwards, though :)

cholberg avatar May 20 '24 08:05 cholberg

@patrick-kidger FWIW, both of my examples run without errors using your branch 387-tweaks! The only change I had to make was to the signature of my condition function:

def cond_fn(t, y, *args, **kwargs):

Moreover, the behavior is as expected and the solve terminates at e.g. 0.95. (Something interesting is that if I ask it to save intermediate values, the saved output includes values for times after the termination time. I assume this is as expected, as the code then backtracks to the point where cond_fn changes sign?) Thanks to you and to @cholberg for your assistance!

dkweiss31 avatar May 21 '24 12:05 dkweiss31

I'm glad it now works! I have no idea what I changed to make that happen, although I did have a variety of small bugfixes as well.

I think saving anything after the termination time should be considered a bug though :) I knew I must have missed something! Assuming the above MWE demonstrates this, I'll leave this to @cholberg to consider post-NeurIPS-deadline :)

patrick-kidger avatar May 21 '24 22:05 patrick-kidger

One more thing (sorry, it might just be my bad approach): if I want to have multiple events which each happen at a specific point in time, only the last one passed to the Event class seems to be recognized:

import diffrax
import jax.numpy as jnp


def test_continuous_terminate2():
    term = diffrax.ODETerm(lambda t, y, args: y)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = 1
    y0 = 1.0

    event_times = [3, 4, 10]
    cond_fns = [lambda state, **kwargs: state.tprev - t for t in event_times]

    event = diffrax.Event(cond_fn=cond_fns)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)

In this case sol.ts will be 10.0, whereas when I pass event_times = [10, 4, 3], sol.ts will be 3.0. (For me it's no problem to just pick the earliest time point in this case; it just behaved differently from what I expected, which was that the t=3.0 event should be triggered first in either case.)

Additionally, I get an error in this case if the time-dependent event happens (t=9) before the state-value-dependent event (t=10):

import diffrax
import jax
import jax.numpy as jnp
import optimistix as optx


def test_continuous_two_events():
    term = diffrax.ODETerm(lambda t, y, args: 1.0)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = 1.0
    y0 = -10.0

    event_time = 9

    def cond_fn_1(state, y, **kwargs):
        assert isinstance(state.y, jax.Array)
        return y

    def cond_fn_2(state, y, **kwargs):
        assert isinstance(state.y, jax.Array)
        return state.tprev - event_time

    root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
    event = diffrax.Event([cond_fn_1, cond_fn_2], root_finder)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)

-> jaxlib.xla_extension.XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output.

Why do you think this is the case, and what can I do to change it? I assume that if I just pick t1 as the event time, it still has the problem that it could overshoot in time by some value between 0 and the step size? Thank you for your help :)!

LuggiStruggi avatar May 22 '24 07:05 LuggiStruggi

For the second one, that's probably on us.

For the first one, you are running afoul of https://docs.astral.sh/ruff/rules/function-uses-loop-variable/
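
A sketch of the usual fix for that lint: bind the loop variable as a default argument so each lambda captures its own value rather than the last one.

event_times = [3, 4, 10]
# t=t binds the current loop value as a default argument, so each condition
# function keeps its own event time instead of all closing over t = 10.
cond_fns = [lambda state, t=t, **kwargs: state.tprev - t for t in event_times]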

patrick-kidger avatar May 22 '24 14:05 patrick-kidger

Cool, thanks! Wow, I didn't know about this :D Thanks for the help!

LuggiStruggi avatar May 22 '24 19:05 LuggiStruggi

@dkweiss31, @patrick-kidger: I ran your MWE above with steps=True (I am assuming that is what you meant, but correct me if I am wrong). I actually did not notice this when walking through 387-tweaks, but it turns out that the values returned by the root finding are not saved anywhere in this case, which I expect is what caused the confusion. This should be fixed with the latest commit, but, as suggested by @patrick-kidger, I really need to go through and write tests for all the different saving configurations in combination with events to ensure that everything works as expected. Do let me know if you encounter any more unexpected behaviours.

cholberg avatar May 23 '24 15:05 cholberg

@LuggiStruggi: With the latest commits, your second example seems to work for me :)

cholberg avatar May 23 '24 15:05 cholberg

Cool, thanks :) Sorry, maybe I didn't pull recently enough.

LuggiStruggi avatar May 23 '24 15:05 LuggiStruggi

Hey friends, pulling from the most recent version of cholberg:dev, the following example fails:

import diffrax as dx
import optimistix as optx
import jax.numpy as jnp
import equinox as eqx
from jax import Array


L_op = 0.1 * jnp.array([[0.0, 1.0],
                        [0.0, 0.0]], dtype=complex)
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[0.0], [1.0]], dtype=complex)
ts = jnp.arange(t0, t1, dt0)


class Saved(eqx.Module):
    y: Array

def save_fn(t, y, args):
    ynorm = jnp.einsum("ij,ij->", y, y)
    return Saved(jnp.array([ynorm,]))

subsaveat_a = dx.SubSaveAt(ts=ts, fn=save_fn)
subsaveat_b = dx.SubSaveAt(t1=True)
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])


def vector_field(t, state, _args):
    L_d_L = jnp.transpose(L_op) @ L_op
    new_state = -1j * (- 1j * 0.5 * L_d_L) @ state
    return new_state


def cond_fn(t, y, *args, **kwargs):
    psi = y
    prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
    return prob - 0.95


term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)

sol = dx.diffeqsolve(
    term, dx.Tsit5(), t0, t1, dt0, y0, event=event, saveat=saveat
)

on line 715 of _integrate.py with

    save_index = final_state.save_state.save_index - 1
AttributeError: 'list' object has no attribute 'save_index'

dkweiss31 avatar May 24 '24 14:05 dkweiss31

@dkweiss31 -- thank you for the report! @cholberg -- looks like it should be a tree-map'd save_index = jtu.tree_map(lambda x: x.save_index - 1, final_state.save_state). I guess let's add "a PyTree of SubSaveAts" to the list of things we should add a test for.

patrick-kidger avatar May 25 '24 21:05 patrick-kidger