DifferentialEquations.jl
Wrong integrator.t (t not in tstops) sent to callback during Zygote gradient calculation.
Not entirely sure if this should fall under DifferentialEquations.jl but here we go.
I am implementing a NeuralODE with callbacks myself using DifferentialEquations (not using DiffEqFlux at the moment). I am noticing some strange behaviour in the time steps being sent to the callback function when calculating the gradient of the model with respect to a set of observations.
Essentially, I have a callback that sets a parameter (let's call it p) from 0 to a specific value and back to 0.
As an example: for t < 0, p = 0; a tstop at t = 0 is triggered and sets p = 100; then after some interval, for example at t = 1, another tstop is triggered and sets p = 0.
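The intended schedule for p can be written down as a small standalone function. This is only a sketch in plain Julia with no solver involved; the names SWITCH_TIMES, SWITCH_VALUES, INITIAL_VALUE and scheduled_value are hypothetical, and only the times and values are taken from the example above.

```julia
# Hypothetical names throughout; only the times and values come from the example above.
const SWITCH_TIMES  = [0.0, 1.0]     # the two tstops
const SWITCH_VALUES = [100.0, 0.0]   # value p takes at each tstop
const INITIAL_VALUE = 0.0            # value of p before the first tstop

# Value p is expected to hold at time t under the described schedule.
function scheduled_value(t)
    # Index of the last switch time <= t, or 0 if t lies before all of them.
    idx = searchsortedlast(SWITCH_TIMES, t)
    return idx == 0 ? INITIAL_VALUE : SWITCH_VALUES[idx]
end

for t in (-0.5, 0.0, 0.5, 1.0, 2.0)
    println("t = ", t, "  ->  p = ", scheduled_value(t))
end
```

Under this schedule, any time strictly between the two tstops (such as t = 0.5) should still see p = 100.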
For most evaluations, the gradient calculates fine.
I found that sometimes, however, depending on the values of the other parameters, the solve call adds a time step between my tstops, i.e. between t = 0 and t = 1, and as a result the gradient calculation using Zygote throws an error. I traced this to the integrator time that is passed to the affect! function of my callback.
The condition() and affect!() are as follows (pseudo-code):
values = [100, 0]
tstops = [0, 1]
condition(u, t, p) = t in tstops
affect!(integrator) = integrator.p[end] = values[findfirst(isequal(integrator.t), tstops)]
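To see what this lookup does for times on and off the tstops grid, here is a minimal sketch in plain Julia, without any solver or automatic differentiation. The names stop_times, stop_values and lookup are hypothetical stand-ins for tstops, values and the indexing expression in the pseudo-code.

```julia
stop_times  = [0.0, 1.0]     # stand-in for tstops
stop_values = [100.0, 0.0]   # stand-in for values

# Same lookup as in the pseudo-code, written as a standalone function.
lookup(t) = stop_values[findfirst(isequal(t), stop_times)]

@show lookup(0.0)
@show lookup(1.0)

# For a time that is not a stop time, findfirst finds no match.
@show findfirst(isequal(0.5), stop_times)

# Indexing with that result raises an exception, which is caught here
# only to record that it happened.
threw = try
    lookup(0.5)
    false
catch
    true
end
@show threw
```

So any call with a time that is not exactly one of the stop times makes the unguarded lookup fail.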
The error originates from the affect! function. During the gradient calculation, affect! is called not at one of my tstops but at the time step in between them, so findfirst returns nothing and indexing values fails. By adding a println in the affect! function I can see the following happening:
function affect!(integrator)
    println("From affect!: integrator.t = ", integrator.t)
    idx = findfirst(isequal(integrator.t), tstops)
    idx === nothing && return  # exit the function to prevent the error
    integrator.p[end] = values[idx]
end
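The effect of such a guard can be checked in isolation. The sketch below uses a hypothetical MockIntegrator struct with only the t and p fields in place of the real integrator, and compares times with a small absolute tolerance (atol is an assumption here, not something the original code does). It suppresses the error but does not explain why an off-grid time reaches the callback.

```julia
# Hypothetical stand-in for the solver's integrator: only the two fields used here.
mutable struct MockIntegrator
    t::Float64
    p::Vector{Float64}
end

const GRID_TIMES  = [0.0, 1.0]     # stand-in for tstops
const GRID_VALUES = [100.0, 0.0]   # stand-in for values

# Guarded update: change p[end] only when t matches a stop time within `atol`.
function guarded_affect!(integrator; atol=1e-10)
    idx = findfirst(s -> isapprox(s, integrator.t; atol=atol), GRID_TIMES)
    idx === nothing && return integrator   # off-grid time: leave p untouched
    integrator.p[end] = GRID_VALUES[idx]
    return integrator
end

on_grid  = guarded_affect!(MockIntegrator(0.0, [1.0, 0.0]))   # t is a stop time
off_grid = guarded_affect!(MockIntegrator(0.5, [1.0, 0.0]))   # t is between stop times

println("on grid:  p = ", on_grid.p)
println("off grid: p = ", off_grid.p)
```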
When no error occurs:
julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.0, 0.0, KpU, 1, Caa)
From affect!: integrator.t = TrackedReal<6Zy>(0.0, 0.0, 2dM, 1, Jmq)
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)
When the error occurs:
julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.5, 0.0, KpU, 1, Caa) <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<6Zy>(0.5, 0.0, 2dM, 1, Jmq) <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)
What is happening here? I hope someone can point me in the right direction. I can try to create a minimal working example if this doesn't immediately ring any bells.
@frankschae can you take a look at this? I just found it on the bottom of my todo email and realized I never looked into it. But I know you've done some improvements to this part of the adjoints recently so maybe this just works?