DifferentialEquations.jl

Wrong integrator.t (t not in tstops) sent to callback during Zygote gradient calculation.

Open Janssena opened this issue 1 year ago • 1 comments

Not entirely sure if this should fall under DifferentialEquations.jl but here we go.

I am implementing a NeuralODE with callbacks myself using DifferentialEquations (not using DiffEqFlux atm). I am noticing some strange behaviour with the time steps being sent to the callback function when calculating the gradient of the model with respect to a set of observations.

Essentially, I have a callback that sets a parameter (let's call it p) from 0 to a specific value and back to 0. As an example: for t < 0, p = 0; a tstop at t = 0 is triggered and sets p = 100; then after some interval, for example at t = 1, another tstop is triggered and sets p = 0. For most evaluations, the gradient is calculated fine. Sometimes, however, depending on the values of the other parameters, the solve call adds a time step between my tstops, i.e. between t = 0 and t = 1, and as a result the gradient calculation using Zygote errors. I found that this is due to the integrator time that gets passed to the affect! function of my callback.

The condition() and affect!() are as follows (pseudo-code):

values = [100, 0]
tstops = [0, 1]
condition(u, t, integrator) = t in tstops
affect!(integrator) = integrator.p[end] = values[findfirst(isequal(integrator.t), tstops)]

The error originates from the affect! function. During gradient calculation, instead of being called at my tstops, affect! is called with the time step in between my tstops. By adding a println to the affect! function I can see the following happening:

function affect!(integrator)
  println("From affect!: integrator.t = ", integrator.t)
  idx = findfirst(isequal(integrator.t), tstops)
  if idx === nothing
    return # exit the function to prevent error
  end
  integrator.p[end] = values[idx]
end
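The guard above can also be written so that it does not depend on exact floating-point equality. What follows is only a minimal sketch in plain Julia, not the solver's own behaviour: MockIntegrator, stop_index, stop_values and the 1e-12 tolerance are hypothetical names and values introduced for illustration, and the times are assumed to be plain Float64 (tracked values seen during gradient calculation may need unwrapping first).

```julia
# Hypothetical stand-in for the solver's integrator: only the two fields used here.
mutable struct MockIntegrator
    t::Float64
    p::Vector{Float64}
end

tstops = [0.0, 1.0]          # times at which the parameter should change
stop_values = [100.0, 0.0]   # value of p[end] to set at each tstop

# Index of the tstop that matches `t` within `atol`, or `nothing` if there is none.
stop_index(t; atol = 1e-12) = findfirst(s -> isapprox(s, t; atol = atol), tstops)

function affect!(integrator)
    idx = stop_index(integrator.t)
    idx === nothing && return nothing   # not at a tstop: leave p untouched
    integrator.p[end] = stop_values[idx]
    return nothing
end

integ = MockIntegrator(0.0, [0.5, 0.0])
affect!(integ)    # at t = 0 the last parameter is set to 100
integ.t = 0.5
affect!(integ)    # between tstops nothing changes
```

This only hides the symptom: it keeps affect! from erroring when it is called at an unexpected time, but does not explain why that time is passed in the first place.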

When no error occurs:

julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.0, 0.0, KpU, 1, Caa)
From affect!: integrator.t = TrackedReal<6Zy>(0.0, 0.0, 2dM, 1, Jmq)
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)

When the error occurs:

julia> Zygote.gradient(...)
# stuff from the forward call
From affect!: integrator.t = TrackedReal<9sm>(0.5, 0.0, KpU, 1, Caa) <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<6Zy>(0.5, 0.0, 2dM, 1, Jmq) <- the value between my tstops is passed rather than 0
From affect!: integrator.t = TrackedReal<48h>(1.0, 0.0, E8U, 1, 70e)
From affect!: integrator.t = TrackedReal<K76>(1.0, 0.0, EW0, 1, 97m)
From affect!: integrator.t = TrackedReal<7Nk>(0.0, 0.0, LgX, 1, 6nD)
From affect!: integrator.t = TrackedReal<486>(0.0, 0.0, 1Y0, 1, DXb)

What is happening here? I hope someone can point me in the right direction. I can try to put together a minimal working example if this doesn't immediately ring any bells.

Janssena, Apr 24 '23 12:04

@frankschae can you take a look at this? I just found it on the bottom of my todo email and realized I never looked into it. But I know you've done some improvements to this part of the adjoints recently so maybe this just works?

ChrisRackauckas, Aug 13 '23 15:08