DiffEqDocs.jl

Rewinding integrator

Open lgravina1997 opened this issue 1 year ago • 6 comments

I am looking for a way to rewind the integrator to the state it had at an earlier time. Specifically, I want to use a continuous callback to monitor some condition. When this condition is satisfied, I want to go back Δt in time, modify the solution at that point, and restart the integration from there.

Is there a way to do so?
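To pin down the requested behavior independently of DifferentialEquations.jl, here is a toy sketch in plain Julia. It is not how the library works internally: `euler_with_rewind` and its keyword arguments are hypothetical names, and a fixed-step forward Euler loop stands in for a real solver. The loop keeps its own history; when `trigger` fires, it truncates the history back to about t - Δt, applies `modify` to the state there, and continues.

```julia
# Toy sketch (plain Julia, NOT DifferentialEquations.jl): forward Euler that keeps its
# own history so it can rewind. All names here are hypothetical.
function euler_with_rewind(f, u0, tspan, dt; trigger, Δt, modify)
    ts = [float(tspan[1])]
    us = [float(u0)]
    rewound = false
    while ts[end] < tspan[2] - dt / 2              # stop once the final time is reached
        t, u = ts[end], us[end]
        push!(ts, t + dt)                          # one explicit Euler step
        push!(us, u + dt * f(u, t))
        if !rewound && trigger(us[end], ts[end])
            # rewind: keep only the history up to (about) t - Δt
            k = findlast(<=(ts[end] - Δt + dt / 2), ts)
            resize!(ts, k)
            resize!(us, k)
            us[end] = modify(us[end])              # modify the state at the rewound time
            rewound = true                         # rewind at most once in this sketch
        end
    end
    return ts, us
end

# Usage: exponential growth; when u first exceeds 2, go back 0.1 in time and halve u there.
ts, us = euler_with_rewind((u, t) -> u, 1.0, (0.0, 1.0), 0.01;
                           trigger = (u, t) -> u > 2, Δt = 0.1, modify = u -> u / 2)
```

Because the history is truncated before resuming, every time point appears exactly once in `ts`, which is the behavior asked for here.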

lgravina1997, Oct 31 '23 11:10

You can use `reinit!`.

ChrisRackauckas, Nov 01 '23 06:11

Indeed, `reinit!` serves the purpose, but it has one problem that might be worth an issue of its own.

Consider the following simple problem:

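```julia
using OrdinaryDiffEq

f(u, p, t) = u            # placeholder right-hand side; the issue does not show f
u0 = 1.0                  # Float64 initial condition
tl = LinRange(0, 1, 1001)
tspan = (tl[1], tl[end])
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), saveat = tl)
```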

Now assume the solution is changed at t=t0=0.5. We can do this with a callback:

```julia
t0 = 0.5
condition(u, t, integrator) = t == t0       # fires only when t0 is hit exactly (hence tstops below)
affect!(integrator) = (integrator.u = 1.0)  # overwrite the state at t0

cb = DiscreteCallback(condition, affect!, save_positions = (false, false))
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), callback = cb, tstops = [t0], saveat = tl)
```

Now assume that at t=t1=0.75 we discover the solution is wrong and want to go back to t=t0 and redo the evolution correctly. We can do:

```julia
t1 = 0.75
u_t0 = sol.u[argmin(abs.(tl .- t0))]   # saved solution at t0 (grid point nearest t0)

condition_1(u, t, integrator) = (t == t1) && (integrator.u > 1.2 * u_t0)

function affect_1!(integrator)
    # rewind: restore the state saved at t0 and restart from there,
    # keeping the solution saved so far (erase_sol = false)
    integrator.u = u_t0
    integrator.t = t0
    reinit!(integrator, integrator.u; t0 = t0, erase_sol = false)
end

cb   = DiscreteCallback(condition, affect!, save_positions = (false, false))
cb1  = DiscreteCallback(condition_1, affect_1!, save_positions = (false, false))
sol1 = solve(prob, Tsit5(), callback = CallbackSet(cb, cb1), tstops = [t0, t1], saveat = tl)
```

The problem is the following:

  • If we use erase_sol=true in reinit!, the entire solution saved before t=t1 is lost. We only want the solution for t>=t0 to be overwritten; the solution before t=t0 should be kept.
  • If we use erase_sol=false in reinit!, the solution before t=t0 is retained, as we want. However, the solution between t=t0 and t=t1 (t0<t<t1) is also retained, so this time interval appears twice in the saved solution.

It would be useful if erase_sol=true erased only the part of the solution after the t0 passed as a keyword argument.
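The proposed semantics amount to truncating the saved arrays at t0 instead of clearing them. A minimal sketch on plain vectors, using a hypothetical helper `erase_after!` (not part of `reinit!`'s API) and placeholder values; it assumes the saved times are sorted in ascending order, which holds at the moment of the rewind.

```julia
# Hypothetical helper sketching the proposed semantics: erase only the saved points
# after t0 and keep everything at or before t0. Assumes `ts` is sorted ascending.
function erase_after!(ts::Vector{<:Real}, us::Vector, t0::Real)
    idx = something(findlast(<=(t0), ts), 0)   # last saved time not after t0 (0 if none)
    resize!(ts, idx)
    resize!(us, idx)
    return idx
end

# Usage: saved up to t1 = 0.75 (placeholder values), then rewind to t0 = 0.5.
ts = collect(0.0:0.25:0.75)
us = [1.0, 1.3, 1.6, 2.1]
n_kept = erase_after!(ts, us, 0.5)
```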

Below is an example of the problem as it manifests with erase_sol=false:

[Screenshot 2023-11-07 at 15 40 03]

lgravina1997, Nov 07 '23 14:11

I tried to work around this with

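```julia
function affect_1!(integrator)
    # `select(t0, tl)` in the original is a helper that isn't shown; the nearest-index
    # lookup below is assumed to do the same thing
    integrator.u = sol.u[argmin(abs.(tl .- t0))]
    integrator.t = t0
    reinit!(integrator, integrator.u; t0 = t0, erase_sol = false)

    # drop everything saved after t0
    idx = findlast(integrator.sol.t .<= t0)
    resize!(integrator.sol.t, idx)
    resize!(integrator.sol.u, idx)
end
```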

but this throws the following error (raised at the end of the integration, not at the moment of resizing):

```
BoundsError: attempt to access 1001-element Vector{Float64} at index [1251]

Stacktrace:
  [1] getindex
    @ ./essentials.jl:13 [inlined]
  [2] solution_endpoint_match_cur_integrator!(integrator::OrdinaryDiffEq.OD ...
  [3] _postamble!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), ...
```

lgravina1997, Nov 07 '23 16:11

You'll also need to reset `saveiter`. This relies on internals, so it won't be the most robust approach: indexing the saves via `saveiter` hasn't changed in 6 years, so in theory it's fine, but in practice it depends on an internal, non-public API.
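A toy model may help show why the save counter has to be resynchronized. It rests on an assumption, not checked against OrdinaryDiffEq's source: saves are written at a running counter, and a save whose counter is past the end of the vector is pushed instead. `ToySaver`, `toy_save!`, `toy_rewind!` and `replay` are hypothetical names. With the issue's numbers (1001 save times, rewind at t1 = 0.75 back to t0 = 0.5), truncating without resetting the counter leaves 1001 saved points but a counter of 1251, the same length and index as in the BoundsError above.

```julia
# Toy model (an assumption, not OrdinaryDiffEq's actual code): saves are written at a
# running counter `iter`; writing past the end of the vector pushes instead.
mutable struct ToySaver
    t::Vector{Float64}
    iter::Int
end

function toy_save!(s::ToySaver, t)
    s.iter += 1
    if s.iter <= length(s.t)
        s.t[s.iter] = t      # overwrite an existing slot
    else
        push!(s.t, t)        # grow the vector
    end
    return s
end

# Rewind to t0 by truncating the saved times; optionally resync the counter.
function toy_rewind!(s::ToySaver, t0; reset_iter::Bool)
    idx = something(findlast(<=(t0), s.t), 0)
    resize!(s.t, idx)
    if reset_iter
        s.iter = idx         # the analogue of `integrator.saveiter = idx`
    end
    return s
end

# Replay the issue's numbers: save up to t1 = 0.75, rewind to t0 = 0.5, save the rest.
tl = collect(range(0, 1; length = 1001))
function replay(reset_iter)
    s = ToySaver(Float64[], 0)
    foreach(t -> toy_save!(s, t), tl[1:751])      # saves up to t1
    toy_rewind!(s, 0.5; reset_iter = reset_iter)
    foreach(t -> toy_save!(s, t), tl[502:end])    # saves after t0
    return s
end
stale = replay(false)   # counter not reset
fixed = replay(true)    # counter reset after truncation
```

In the toy, both runs store the same times; only the stale counter points past the end, so any later read at that counter fails.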

ChrisRackauckas, Nov 08 '23 07:11

This is perfect; it indeed works. The only remaining problem is that a PresetTimeCallback interferes with resetting the integrator. Specifically, take the example from before:


```julia
using DiffEqCallbacks   # provides PresetTimeCallback

t1 = 0.75
u_t0 = sol.u[argmin(abs.(tl .- t0))]   # saved solution at t0

condition_1(u, t, integrator) = (t == t1) && (integrator.u > 1.2 * u_t0)

function affect_1!(integrator)
    integrator.u = sol.u[argmin(abs.(tl .- t0))]   # `select(t0, tl)` in the original
    reinit!(integrator, integrator.u; t0 = t0, erase_sol = false)

    # drop everything saved after t0 and resync the save counter
    idx = findlast(integrator.sol.t .<= t0)
    resize!(integrator.sol.t, idx)
    resize!(integrator.sol.u, idx)
    integrator.saveiter = idx   # internal, non-public field
end

cb   = DiscreteCallback(condition, affect!, save_positions = (false, false))
cb1  = DiscreteCallback(condition_1, affect_1!, save_positions = (false, false))
cb2  = PresetTimeCallback(tl, x -> x, save_positions = (false, false))   # affect! that does nothing
sol1 = solve(prob, Tsit5(), callback = CallbackSet(cb, cb1, cb2), tstops = [t0, t1], saveat = tl)
```

This code works perfectly without the new PresetTimeCallback cb2. When I include it, I get the error `Tried to add a tstop that is behind the current time. This is strictly forbidden.` It likely comes from the initialization of cb2 running after the integrator has been reinitialized.
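If the error does come from preset times being registered again after the rewind, as suspected above, the condition to avoid is registering any time earlier than the integrator's current time. A minimal sketch of that filtering step, with a hypothetical helper `times_ahead`; whether PresetTimeCallback can be given such a filtered list at that point is an assumption not checked here.

```julia
# Hypothetical helper: keep only the preset times strictly after the current time,
# so none of them is behind the integrator after a rewind to `tcur`.
times_ahead(times, tcur) = filter(>(tcur), times)

tl = collect(range(0, 1; length = 1001))
ahead = times_ahead(tl, 0.5)   # times that could still be registered after rewinding to t0 = 0.5
```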

Interestingly, dropping cb2 and instead passing tstops=vcat(tl, [t0, t1]), i.e.

```julia
sol1 = solve(prob, Tsit5(), callback = CallbackSet(cb, cb1), tstops = vcat(tl, [t0, t1]), saveat = tl);
```

does not cause this problem. Is this something that should be fixed in PresetTimeCallback, or am I mistaken?

lgravina1997, Nov 09 '23 11:11

I'd need an MWE for this.

ChrisRackauckas, Nov 26 '23 14:11