
`type TrackedAffect has no field funciter` when training hybrid models with implicit solvers

Open ThummeTo opened this issue 4 years ago

The error (see title) occurs when training a hybrid model (DiffEqFlux.jl) with backward sensitivity methods like ReverseDiffAdjoint while using rootfind=RightRootFind. If this is more of a DiffEqCallbacks issue, feel free to re-organize the thread.

Please note that the provided MWE is a copy of the official DiffEqFlux tutorial for hybrid modelling. The only ~~three~~ four things modified are:

  • a VectorContinuousCallback for state events was added (with option rootfind=RightRootFind)
  • a dummy condition for state-callback triggering was added
  • both callbacks are added to the solve-call
  • EDIT: I forgot to say that I am using an implicit solver like Rosenbrock23(autodiff=false). For other implicit solvers like Rodas4(autodiff=false) I still get errors, but different ones (MethodError: no method matching OrdinaryDiffEq.RodasTableau(::ReverseDiff.TrackedReal{Float32... or MethodError: no method matching (::DiffEqSensitivity.TrackedAffect{Float32...). There seems to be no issue with explicit solvers like e.g. Tsit5().

Thanks in advance and best regards!

MWE (tested in Julia 1.6.5 LTS with current library releases):

using DiffEqFlux, DifferentialEquations, Plots
import SciMLBase: RightRootFind

u0 = Float32[2.; 0.]
datasize = 100
tspan = (0.0f0,10.5f0)
dosetimes = [1.0,2.0,4.0,8.0]

function affect!(integrator)
    integrator.u = integrator.u.+1
end
cb_ = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
function trueODEfunc(du,u,p,t)
    du .= -u
end
t = range(tspan[1],tspan[2],length=datasize)

prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),callback=cb_,saveat=t))
dudt2 = Chain(Dense(2,50,tanh),
              Dense(50,2))
p,re = Flux.destructure(dudt2) # use this p as the initial condition!

function dudt(du,u,p,t)
    du[1:2] .= -u[1:2]
    du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end])
end
z0 = Float32[u0;u0]
prob = ODEProblem(dudt,z0,tspan)

affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end]
timecb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))

# NEW THINGS START
function condition(out, u, t, integrator) 
    out[1] = u[1]-0.1
end

statecb = VectorContinuousCallback(condition,
                                   (integrator, idx) -> affect!(integrator),
                                   1;
                                   rootfind=RightRootFind,
                                   save_positions=(false, false),
                                   interp_points=100) 

# NEW THINGS END

function predict_n_ode()
    _prob = remake(prob,p=p)
    Array(solve(_prob, Rosenbrock23(autodiff=false),
                u0=z0, p=p,
                callback=CallbackSet(statecb, timecb),
                saveat=t,
                sensealg=ReverseDiffAdjoint()))[1:2,:]
    #Array(solve(prob,Rosenbrock23(autodiff=false),u0=z0,p=p,saveat=t))[1:2,:]
end

function loss_n_ode()
    pred = predict_n_ode()
    loss = sum(abs2,ode_data .- pred)
    loss
end
loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE

cba = function (;doplot=false) #callback function to observe training
  pred = predict_n_ode()
  display(sum(abs2,ode_data .- pred))
  # plot current prediction against data
  pl = scatter(t,ode_data[1,:],label="data")
  scatter!(pl,t,pred[1,:],label="prediction")
  display(plot(pl))
  return false
end
cba()

ps = Flux.params(p)
data = Iterators.repeated((), 200)
Flux.train!(loss_n_ode, ps, data, ADAM(0.05), cb = cba)

ThummeTo avatar Mar 29 '22 10:03 ThummeTo

@frankschae can you take a look at this one?

ChrisRackauckas avatar Mar 29 '22 11:03 ChrisRackauckas

This error basically occurs whenever a DiscreteCallback is initialized with the functioncalling_initialize routine, e.g. the FunctionCallingCallbacks.

I narrowed it down: the problem is that functioncalling_initialize(cb, u, t, integrator) (Code) is called by default as the initialization function for discrete callbacks (or just for the FunctionCallingCallbacks?). This works fine as long as cb.affect! is a FunctionCallingAffect, but as soon as a TrackedAffect is deployed, the cb struct looks different: there is basically an additional layer of affect!. For example, cb.affect!.funciter with a FunctionCallingAffect becomes cb.affect!.affect!.funciter with a TrackedAffect.
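To illustrate the layering, here is a minimal, self-contained sketch. The struct names mirror the real ones, but the definitions below are simplified stand-ins, not the actual library types:

```julia
# Simplified stand-ins for the real types (field layout is illustrative only).
struct FunctionCallingAffect
    funciter::Int          # the field the initializer expects to find
end

struct TrackedAffect{A}
    affect!::A             # wraps the original affect!, adding one indirection layer
end

inner   = FunctionCallingAffect(1)
wrapped = TrackedAffect(inner)

inner.funciter             # works: direct field access
wrapped.affect!.funciter   # works: one extra layer of unwrapping
# wrapped.funciter         # ERROR: type TrackedAffect has no field funciter
```

An initializer that unconditionally accesses cb.affect!.funciter hits exactly the commented-out error once the affect has been wrapped.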

So my proposal to fix this is:

  1. subdivide the initialization routine functioncalling_initialize in DiffEqCallbacks into two functions, so we can use multiple dispatch on the type of cb.affect!, see DiffEqCallbacks PR.
  2. extend the callback-initialization function to correctly handle the TrackedAffect introduced in SciMLSensitivity, see SciMLSensitivity PR.
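The two-step proposal can be sketched like this. All names below (the structs and reset_funciter!) are hypothetical simplifications for illustration; the real signatures live in the linked PRs:

```julia
# Simplified stand-ins for the real types.
struct FunctionCallingAffect
    funciter::Int
end
struct TrackedAffect{A}
    affect!::A
end

# Step 1 (DiffEqCallbacks): factor the affect-specific part out of
# functioncalling_initialize so it dispatches on the affect type.
reset_funciter!(affect::FunctionCallingAffect) = affect.funciter

# Step 2 (SciMLSensitivity): extend that function for the wrapper type
# by unwrapping one layer and forwarding.
reset_funciter!(affect::TrackedAffect) = reset_funciter!(affect.affect!)

reset_funciter!(TrackedAffect(FunctionCallingAffect(1)))  # → 1
```

With this split, the initializer never touches field layout directly, so any future wrapper type only needs one extra one-line method.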

What do you think @frankschae? Best regards, Tobi

ThummeTo avatar Mar 17 '23 09:03 ThummeTo


Forgot to say: all libraries are on their current releases. The example at the top is outdated (and was far too specific), but I hope the error itself is understandable. Let me know how I can support.

ThummeTo avatar Mar 17 '23 09:03 ThummeTo

This is closed together with #801

ThummeTo avatar Mar 27 '23 06:03 ThummeTo