DiffEqBase.jl

WIP: support multiple simultaneous events in a single VectorContinuousCallback

Open MasonProtter opened this issue 1 month ago • 3 comments

This is an attempt to make VectorContinuousCallback support multiple events that trigger simultaneously, as discussed on Slack.

What Works

Basic demo:

using OrdinaryDiffEqTsit5, Plots

function f!(du, u, p, t)
    (x1, x2) = u
    du[1] = -u[1]
    du[2] = -u[2]
    nothing
end

function my_condition(out, u, t, integrator)
    (; t1, t2) = integrator.p
    out[1] = t - t1
    out[2] = t - t2
end

function my_affect!(integrator, event_index)
    if event_index == 1
        integrator.u[1] += 1
    elseif event_index == 2
        integrator.u[2] += 1
    else
        error("huh? $event_index")
    end
end

function run(; t1=1.0, t2=1.0)
    cb = VectorContinuousCallback(my_condition, my_affect!, 2)
    prob = ODEProblem(f!, [1.0, 1.0], (0.0, 5.0), (;t1, t2), callback=cb)
    plot(solve(prob, Tsit5()))
end
julia> run()

Before: [plot image]

After: [plot image]
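In the demo above, both t - t1 and t - t2 cross zero at t = 1.0, so a single step contains two roots at exactly the same time. As a rough illustration of what handling "multiple simultaneous events" means, here is a minimal plain-Julia sketch. It is not how the PR or DiffEqBase implements it; the helper simultaneous_events and its tdir and atol parameters are hypothetical. Given a root-time estimate per component, it picks the first root in the integration direction, collects every component whose root coincides with it, and applies the affect once per index:

```julia
# Hypothetical helper (not DiffEqBase API): given one root-time estimate per
# component of a vector condition inside a single step (NaN = no root found),
# return the first root time in the integration direction `tdir` (+1 forward,
# -1 backward) and every component whose root lies within `atol` of it.
function simultaneous_events(root_times::AbstractVector{<:Real}; tdir = 1, atol = 1e-12)
    found = findall(!isnan, root_times)
    isempty(found) && return (nothing, Int[])
    candidates = root_times[found]
    t_first = tdir > 0 ? minimum(candidates) : maximum(candidates)
    idxs = [i for i in found if abs(root_times[i] - t_first) <= atol]
    return (t_first, idxs)
end

# Demo case t1 = t2 = 1.0: both components fire at t = 1.0, so the affect is
# applied once per index, as my_affect!(integrator, event_index) expects.
u = [0.5, 0.5]
t_event, idxs = simultaneous_events([1.0, 1.0])
for i in idxs
    u[i] += 1   # same update my_affect! makes for event_index = i
end
```

The atol choice decides which nearly coincident roots get merged into one event, and nearly coincident roots are exactly the regime where the bouncing-ball example below runs into trouble.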

What's not working

I think event detection based on derivatives is not quite right. E.g. here's an example where 3 balls bounce off the ground:

using OrdinaryDiffEqTsit5, Plots

function three_bouncing_balls!(du, u, p, t)
    (x1, v1, x2, v2, x3, v3) = u
    
    dx1 = @view du[1]
    dv1 = @view du[2]
    
    dx2 = @view du[3]
    dv2 = @view du[4]
    
    dx3 = @view du[5]
    dv3 = @view du[6]
    
    dx1[] = v1
    dv1[] = -9.8

    dx2[] = v2
    dv2[] = -9.8

    dx3[] = v3
    dv3[] = -9.8
    nothing
end


function bounce_condition(out, u, t, integrator)
    out[1] = u[1]
    out[2] = u[3]
    out[3] = u[5]
end

function bounce_affect!(integrator, event_index)
    if event_index == 1
        integrator.u[2] *= -1
    elseif event_index == 2
        integrator.u[4] *= -1
    elseif event_index == 3
        integrator.u[6] *= -1
    else
        error("huh? $event_index")
    end
end

function run_balls(;x1=1.0, x2=1.0, x3=1.0, multi_event=true, tspan=(4.0, 8.0), xlims=Tuple(sort(collect(tspan))))
    cb = VectorContinuousCallback(bounce_condition, bounce_affect!, 3)
    prob = ODEProblem(three_bouncing_balls!, [x1, 0.0, x2, 0.0, x3, 0.0], tspan, Float64[], callback=cb)
    sol = solve(prob, Tsit5(), maxiters=500)
    plot(sol, idxs=[1,3,5]; legend=:topright, ylims=(-1, 2), xlims)
end

It works when all balls bounce simultaneously:

julia> run_balls(x1=1.0, x2=1.0, x3=1.0)
[plot image]

but it goes wrong when they bounce at nearly the same time:

julia> run_balls(x1=1, x2=1.001, x3=1)
ERROR: Double callback crossing floating pointer reducer errored. Report this issue.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:44
  [2] find_callback_time(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, callback::VectorContinuousCallback{…}, counter::Int64)
    @ DiffEqBase ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:510
  [3] macro expansion
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:130 [inlined]
  [4] find_first_continuous_callback
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:125 [inlined]
  [5] find_first_continuous_callback
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:123 [inlined]
  [6] handle_callbacks!
    @ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl:379 [inlined]
  [7] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl:284
  [8] loopfooter!
    @ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl:248 [inlined]
  [9] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/solve.jl:612
 [10] #__solve#49
    @ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/solve.jl:7 [inlined]
 [11] __solve
    @ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/solve.jl:1 [inlined]
 [12] #solve_call#25
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:142 [inlined]
 [13] solve_call
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:109 [inlined]
 [14] #solve_up#32
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:578 [inlined]
 [15] solve_up
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:555 [inlined]
 [16] #solve#31
    @ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:545 [inlined]
 [17] run_balls(; x1::Float64, x2::Float64, x3::Float64, multi_event::Bool, tspan::Tuple{…}, xlims::Tuple{…})
    @ Main ~/Nextcloud/Julia/sciml_stuff/scrap/scrap.jl:87
 [18] top-level scope
    @ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

The error is raised in this block of find_callback_time in DiffEqBase's src/callbacks.jl:

                            if integrator.event_last_time == counter &&
                               idx ∈ prev_simultaneous_events &&
                               abs(zero_func(bottom_t)) <=
                               100abs(integrator.last_event_error) &&
                               prev_sign_index == 1

                                # Determined that there is an event by derivative
                                # But floating point error may make the end point negative
                                bottom_t += integrator.dt * callback.repeat_nudge
                                sign_top = sign(zero_func(top_t))
                                sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) &&
                                    error("Double callback crossing floating pointer reducer errored. Report this issue.")
                            end

If I manually disable that block, the code does seem to run fine, but it will miss events if I run it backwards in time.

It's not clear to me how to proceed with this block, since I don't really understand what it is doing, nor the rationale behind how prev_sign_index is generated.
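To make the block a little more concrete, here is a hypothetical, stand-alone sketch of the sign test it performs. It is not DiffEqBase code: g stands in for zero_func, and repeat_nudge = 1//100 is an assumed default. Right after an event the condition is about zero but may carry rounding error of either sign, so the lower end of the bracket is nudged forward by dt * repeat_nudge before checking for a sign change. One thing the sketch makes visible is that the nudge scales with dt, so a root closer to the previous event than dt * repeat_nudge is stepped over. That is at least consistent with smaller steps helping, though whether this is what actually happens in the PR is unverified.

```julia
# Hypothetical illustration (not DiffEqBase code) of the test in the block above.
# Right after an event at t_prev the condition g(t_prev) is ~0, but rounding can
# give it either sign. The lower end of the bracket is therefore nudged forward
# by dt * repeat_nudge, and a new event requires g to change sign between the
# nudged lower end and t_next. repeat_nudge = 1//100 is an assumed default.
function sign_change_after_nudge(g, t_prev, t_next; repeat_nudge = 1 // 100)
    dt = t_next - t_prev                  # negative when integrating backwards
    bottom_t = t_prev + dt * repeat_nudge
    sign_top = sign(g(t_next))
    # The original code errors when sign(g(bottom_t)) * sign_top >= 0,
    # i.e. exactly when this function returns false.
    return sign(g(bottom_t)) * sign_top < zero(sign_top)
end

# A ball launched upward from the ground at t = 0 lands again at t = 1.
height(t) = 4.9t - 4.9t^2
sign_change_after_nudge(height, 0.0, 1.2)       # step contains the landing

# A condition whose next root is only 0.005 after t_prev.
close_roots(t) = t * (t - 0.005)
sign_change_after_nudge(close_roots, 0.0, 1.0)  # nudge of 0.01 jumps past 0.005
sign_change_after_nudge(close_roots, 0.0, 0.1)  # nudge of 0.001 does not
```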

Checklist

  • [ ] Appropriate tests were added
  • [ ] Any code changes were done in a way that does not break public API
  • [ ] All documentation related to code changes were updated
  • [ ] The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • [ ] Any new documentation only uses public API

MasonProtter, Dec 02 '25 19:12

Small update: I was playing around with this and found that I can get the behaviour I want by setting a smaller dtrelax:

function run_balls(;x1=1.0, x2=1.0, x3=1.0, tspan=(4.0, 8.0), xlims=Tuple(sort(collect(tspan))), dtrelax=1)
    cb = VectorContinuousCallback(bounce_condition, bounce_affect!, 3; dtrelax)
    prob = ODEProblem(three_bouncing_balls!, [x1, 0.0, x2, 0.0, x3, 0.0], tspan, Float64[], callback=cb)
    sol = solve(prob, Tsit5(), maxiters=500)
    plot(sol, idxs=[1,3,5]; legend=:topright, ylims=(-1, 2), xlims)
end
julia> run_balls(x1=1, x2=1.001, x3=1, dtrelax=0.1)
[plot image]

Running for a long period of time and then looking at the end, it seems that we're getting the correct offsets for the different bouncing balls:

julia> run_balls(x1=1, x2=1.001, x3=1, dtrelax=0.1, tspan=(4, 100), xlims=(95, 100))
[plot image]

I'm a little surprised, though, that lowering dtrelax to e.g. 0.1 (from the default of 1) avoids the derivative-check error even when the events are 100x or 1000x closer together than in the case that was erroring before. For example,

run_balls(x1=1, x2=1.00001, x3=1, dtrelax=0.1)

still runs fine, and the event for x2 still happens at a distinct time from the events for x1 and x3.
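For a sense of scale, the impact times in this model can be worked out by hand. A rough sketch, assuming perfectly elastic bounces (the velocity sign flip in bounce_affect!), g = 9.8 as in three_bouncing_balls!, and the balls starting at rest at t = 4.0 (the default tspan); impact_time and first_gap are hypothetical helper names:

```julia
# Free fall from rest: h(t) = h0 - g t^2 / 2 reaches zero after sqrt(2 h0 / g).
# With a perfectly elastic bounce each later flight lasts twice that, so
# impact n (n = 0, 1, 2, ...) happens at t0 + (2n + 1) * sqrt(2 h0 / g).
impact_time(h0, n; t0 = 4.0, g = 9.8) = t0 + (2n + 1) * sqrt(2h0 / g)

# Gap between the first impact of the x2 ball and that of the x1/x3 balls.
first_gap(h2; h1 = 1.0) = impact_time(h2, 0) - impact_time(h1, 0)

first_gap(1.001)     # first-bounce gap for run_balls(x1=1, x2=1.001, x3=1)
first_gap(1.00001)   # first-bounce gap for run_balls(x1=1, x2=1.00001, x3=1)
impact_time(1.001, 102) - impact_time(1.0, 102)   # lag at the bounce near t ≈ 96.6
```

Under these assumptions the x2 ball first hits the ground about 2.3e-4 after the others for x2 = 1.001 and about 2.3e-6 for x2 = 1.00001, and the lag grows with every bounce, to about 0.046 by the bounce near t ≈ 96.6, which is the kind of offset the long run over tspan = (4, 100) should show.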

MasonProtter, Dec 03 '25 13:12

This still isn't using the Vector{Bool} form though?

ChrisRackauckas, Dec 03 '25 13:12

I found it a bit more convenient to use a Vector{Int} that gets push!-ed and empty!-ed, but I can switch it over to Vector{Bool} now.
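For reference, a small plain-Julia comparison of the two representations being discussed. The buffer and helper names below are hypothetical, not the PR's actual fields:

```julia
# Two hypothetical ways to store which components of a vector condition fired
# together in the last event (names are illustrative, not the PR's fields).

# (1) Index list, reused between events with empty!/push!.
prev_idxs = Int[]
empty!(prev_idxs)
push!(prev_idxs, 1, 3)

# (2) Bool mask with one entry per component, reset with fill!.
prev_mask = fill(false, 3)
fill!(prev_mask, false)
prev_mask[[1, 3]] .= true

# Membership tests: linear scan of the list vs. direct lookup in the mask.
is_prev_2_list = 2 in prev_idxs
is_prev_2_mask = prev_mask[2]

# Converting between the two forms.
mask_to_idxs(mask::AbstractVector{Bool}) = findall(mask)
function idxs_to_mask!(mask::AbstractVector{Bool}, idxs)
    fill!(mask, false)
    mask[idxs] .= true
    return mask
end
```

The list stays short when only a few components fire. The mask has a fixed length equal to the number of conditions, so it needs no push!/empty! bookkeeping, and a check like idx ∈ prev_simultaneous_events becomes a single lookup.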

MasonProtter, Dec 03 '25 13:12