WIP: support multiple simultaneous events in a single VectorContinuousCallback
This is an attempt to make `VectorContinuousCallback` support multiple events that trigger simultaneously, as discussed on Slack.
What works
Basic demo:
using OrdinaryDiffEqTsit5, Plots
"""
    f!(du, u, p, t)

In-place right-hand side of the two-component decay demo: `du[i] = -u[i]` for
`i = 1, 2`. Both `p` and `t` are ignored.
"""
function f!(du, u, p, t)
    # The original destructured `(x1, x2) = u` but never used the names; dropped.
    du[1] = -u[1]
    du[2] = -u[2]
    return nothing
end
"""
    my_condition(out, u, t, integrator)

Vector event condition for the decay demo. It writes `t - t1` into `out[1]` and
`t - t2` into `out[2]`, so component `i` crosses zero at `t == t_i`. The times
are read from the parameter object `integrator.p`, which must have fields `t1`
and `t2`. The value of the last assignment is returned.
"""
function my_condition(out, u, t, integrator)
    params = integrator.p
    out[1] = t - params.t1
    out[2] = t - params.t2
end
"""
    my_affect!(integrator, event_index)

Event handler for the decay demo. Event 1 adds 1 to `integrator.u[1]` and event
2 adds 1 to `integrator.u[2]`. Any other index raises an error. The updated
state value is returned.
"""
function my_affect!(integrator, event_index)
    target = event_index == 1 ? 1 :
             event_index == 2 ? 2 :
             error("huh? $event_index")
    integrator.u[target] += 1
end
"""
    run(; t1=1.0, t2=1.0)

Solve the two-component decay demo on `t ∈ [0, 5]` from `u0 = [1.0, 1.0]`.
A `VectorContinuousCallback` with two event components is attached, and the
event times are passed as the parameters `(t1, t2)`. Returns the plot of the
solution.

NOTE(review): in `Main` this definition shadows `Base.run`. If `run` was
already resolved to `Base.run` earlier in the session, Julia rejects the
definition. Consider a different name.
"""
function run(; t1=1.0, t2=1.0)
    params = (t1 = t1, t2 = t2)
    prob = ODEProblem(f!, [1.0, 1.0], (0.0, 5.0), params;
                      callback = VectorContinuousCallback(my_condition, my_affect!, 2))
    return plot(solve(prob, Tsit5()))
end
julia> run()
Before:
After:
What's not working
I think event detection based on derivatives is not quite right. For example, here three balls fall from rest and bounce off the ground:
using OrdinaryDiffEqTsit5, Plots
"""
    three_bouncing_balls!(du, u, p, t)

In-place right-hand side for three independent balls in free fall.

The state layout is `u = (x1, v1, x2, v2, x3, v3)`: positions sit at the odd
indices and velocities at the even ones. Each position derivative is the
matching velocity, and each velocity derivative is `-g` with `g = 9.8`.
Both `p` and `t` are ignored.
"""
function three_bouncing_balls!(du, u, p, t)
    g = 9.8
    # Positions are not needed for the derivative, only the velocities.
    (_, v1, _, v2, _, v3) = u
    # Write the scalars directly; the original built 0-dimensional views for
    # each slot, which was needless indirection.
    du[1] = v1
    du[2] = -g
    du[3] = v2
    du[4] = -g
    du[5] = v3
    du[6] = -g
    return nothing
end
"""
    bounce_condition(out, u, t, integrator)

Vector event condition for the three-ball demo. `out[k]` is set to the height
of ball `k`, which is stored at index `2k - 1` of `u = (x1, v1, x2, v2, x3, v3)`.
Event `k` fires when that height crosses zero. The value of the last
assignment (`u[5]`) is returned.
"""
function bounce_condition(out, u, t, integrator)
    h1, h2, h3 = u[1], u[3], u[5]
    out[1] = h1
    out[2] = h2
    out[3] = h3
end
"""
    bounce_affect!(integrator, event_index)

Event handler for the three-ball demo. Event `k` (for `k = 1, 2, 3`) flips the
sign of ball `k`'s velocity, stored at index `2k` of `integrator.u`. Any other
index raises an error. The updated velocity is returned.
"""
function bounce_affect!(integrator, event_index)
    velocity_slot = event_index == 1 ? 2 :
                    event_index == 2 ? 4 :
                    event_index == 3 ? 6 :
                    error("huh? $event_index")
    integrator.u[velocity_slot] *= -1
end
"""
    run_balls(; x1=1.0, x2=1.0, x3=1.0, multi_event=true, tspan=(4.0, 8.0),
              xlims=Tuple(sort(collect(tspan))))

Drop three balls from the heights `x1`, `x2`, `x3`. All three start with zero
velocity. The problem is solved over `tspan` with a 3-component
`VectorContinuousCallback` handling the bounces, capped at `maxiters = 500`.
Returns a plot of the three heights. The default `xlims` sorts `tspan`, so
backwards-in-time spans still plot left-to-right.

NOTE(review): `multi_event` is accepted but never used in the body.
"""
function run_balls(; x1=1.0, x2=1.0, x3=1.0, multi_event=true, tspan=(4.0, 8.0),
                   xlims=Tuple(sort(collect(tspan))))
    callback = VectorContinuousCallback(bounce_condition, bounce_affect!, 3)
    u0 = [x1, 0.0, x2, 0.0, x3, 0.0]
    prob = ODEProblem(three_bouncing_balls!, u0, tspan, Float64[]; callback)
    sol = solve(prob, Tsit5(); maxiters = 500)
    return plot(sol; idxs = [1, 3, 5], legend = :topright, ylims = (-1, 2), xlims)
end
It works when all balls bounce simultaneously:
julia> run_balls(x1=1.0, x2=1.0, x3=1.0)
but it goes wrong when they bounce at nearly, but not exactly, the same time:
julia> run_balls(x1=1, x2=1.001, x3=1)
ERROR: Double callback crossing floating pointer reducer errored. Report this issue.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:44
[2] find_callback_time(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, callback::VectorContinuousCallback{…}, counter::Int64)
@ DiffEqBase ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:510
[3] macro expansion
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:130 [inlined]
[4] find_first_continuous_callback
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:125 [inlined]
[5] find_first_continuous_callback
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/callbacks.jl:123 [inlined]
[6] handle_callbacks!
@ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl:379 [inlined]
[7] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl:284
[8] loopfooter!
@ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl:248 [inlined]
[9] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/solve.jl:612
[10] #__solve#49
@ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/solve.jl:7 [inlined]
[11] __solve
@ ~/Nextcloud/Julia/sciml_stuff/OrdinaryDiffEq/lib/OrdinaryDiffEqCore/src/solve.jl:1 [inlined]
[12] #solve_call#25
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:142 [inlined]
[13] solve_call
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:109 [inlined]
[14] #solve_up#32
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:578 [inlined]
[15] solve_up
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:555 [inlined]
[16] #solve#31
@ ~/Nextcloud/Julia/sciml_stuff/DiffEqBase/src/solve.jl:545 [inlined]
[17] run_balls(; x1::Float64, x2::Float64, x3::Float64, multi_event::Bool, tspan::Tuple{…}, xlims::Tuple{…})
@ Main ~/Nextcloud/Julia/sciml_stuff/scrap/scrap.jl:87
[18] top-level scope
@ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.
The error is raised from this block in `find_callback_time` (DiffEqBase `src/callbacks.jl`):
if integrator.event_last_time == counter &&
idx ∈ prev_simultaneous_events &&
abs(zero_func(bottom_t)) <=
100abs(integrator.last_event_error) &&
prev_sign_index == 1
# Determined that there is an event by derivative
# But floating point error may make the end point negative
bottom_t += integrator.dt * callback.repeat_nudge
sign_top = sign(zero_func(top_t))
sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) &&
error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
If I manually disable that block, the code does seem to run fine, but it misses events when I run it backwards in time.
It's not clear to me how to proceed with this block: I don't really understand what it is doing, nor the rationale behind how `prev_sign_index` is computed.
Checklist
- [ ] Appropriate tests were added
- [ ] Any code changes were done in a way that does not break public API
- [ ] All documentation related to code changes were updated
- [ ] The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
- [ ] Any new documentation only uses public API
Small update: while playing around with this, I found that I can get the behaviour I want by setting a smaller `dtrelax`:
"""
    run_balls(; x1=1.0, x2=1.0, x3=1.0, tspan=(4.0, 8.0),
              xlims=Tuple(sort(collect(tspan))), dtrelax=1)

Three-ball bounce demo, as above, with the callback's `dtrelax` exposed. The
balls start at heights `x1`, `x2`, `x3` with zero velocity. The problem is
solved over `tspan` (at most 500 iterations), and a plot of the three heights
is returned.

`dtrelax` is forwarded to `VectorContinuousCallback`. In this example, values
below the default of 1 (e.g. `0.1`) were found to avoid the double-crossing
error.
"""
function run_balls(; x1=1.0, x2=1.0, x3=1.0, tspan=(4.0, 8.0),
                   xlims=Tuple(sort(collect(tspan))), dtrelax=1)
    callback = VectorContinuousCallback(bounce_condition, bounce_affect!, 3; dtrelax)
    u0 = [x1, 0.0, x2, 0.0, x3, 0.0]
    prob = ODEProblem(three_bouncing_balls!, u0, tspan, Float64[]; callback)
    sol = solve(prob, Tsit5(); maxiters = 500)
    return plot(sol; idxs = [1, 3, 5], legend = :topright, ylims = (-1, 2), xlims)
end
julia> run_balls(x1=1, x2=1.001, x3=1, dtrelax=0.1)
Running over a long time span and then looking at the end, the different bouncing balls appear to keep the correct offsets from one another:
julia> run_balls(x1=1, x2=1.001, x3=1, dtrelax=0.1, tspan=(4, 100), xlims=(95, 100))
I'm a little surprised, though, that setting `dtrelax` to e.g. 0.1 (down from the default of 1) avoids the derivative-check error even when the events are 100x or 1000x closer together than in the case that errored before. For example,
run_balls(x1=1, x2=1.00001, x3=1, dtrelax=0.1)
still runs fine, and the event for x2 still occurs at a distinct time from the events for x1 and x3.
This still isn't using the `Vector{Bool}` form, though?
I found it a bit more convenient to use a `Vector{Int}` that gets `push!`-ed and `empty!`-ed, but I can switch it over to `Vector{Bool}` now.