`type TrackedAffect has no field funciter` when training hybrid models with implicit solvers
The error in the title occurs when training a hybrid model (DiffEqFlux.jl) with a reverse-mode sensitivity method such as `ReverseDiffAdjoint` together with `rootfind=RightRootFind`. If this is more of a DiffEqCallbacks issue, feel free to move the thread.
Please note that the provided MWE is a copy of the official DiffEqFlux tutorial for hybrid modelling. The only ~~three~~ four things modified are:
- a `VectorContinuousCallback` for state events was added (with option `rootfind=RightRootFind`)
- a dummy condition `condition` for triggering the state callback was added
- both callbacks are passed to the `solve` call
- EDIT: I forgot to say that I am using an implicit solver like `Rosenbrock23(autodiff=false)`. With other solvers like `Rodas4(autodiff=false)` I still get errors, but different ones (`MethodError: no method matching OrdinaryDiffEq.RodasTableau(::ReverseDiff.TrackedReal{Float32...` or `MethodError: no method matching (::DiffEqSensitivity.TrackedAffect{Float32...`). There seems to be no issue with explicit solvers such as `Tsit5()`.
Thanks in advance and best regards!
MWE (tested in Julia 1.6.5 LTS with the current library releases):
```julia
using DiffEqFlux, DifferentialEquations, Plots
import SciMLBase: RightRootFind

u0 = Float32[2.; 0.]
datasize = 100
tspan = (0.0f0,10.5f0)

dosetimes = [1.0,2.0,4.0,8.0]
function affect!(integrator)
    integrator.u = integrator.u.+1
end
cb_ = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))

function trueODEfunc(du,u,p,t)
    du .= -u
end
t = range(tspan[1],tspan[2],length=datasize)

prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),callback=cb_,saveat=t))

dudt2 = Chain(Dense(2,50,tanh),
              Dense(50,2))
p,re = Flux.destructure(dudt2) # use this p as the initial condition!

function dudt(du,u,p,t)
    du[1:2] .= -u[1:2]
    du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end])
end
z0 = Float32[u0;u0]
prob = ODEProblem(dudt,z0,tspan)

affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end]
timecb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))

# NEW THINGS START
function condition(out, u, t, integrator)
    out[1] = u[1]-0.1
end

statecb = VectorContinuousCallback(condition,
                                   (integrator, idx) -> affect!(integrator),
                                   1;
                                   rootfind=RightRootFind,
                                   save_positions=(false, false),
                                   interp_points=100)
# NEW THINGS END

function predict_n_ode()
    _prob = remake(prob,p=p)
    Array(solve(_prob,Rosenbrock23(autodiff=false),u0=z0,p=p,callback=CallbackSet(statecb, timecb),saveat=t,sensealg=ReverseDiffAdjoint() ))[1:2,:]
    #Array(solve(prob,Rosenbrock23(autodiff=false),u0=z0,p=p,saveat=t))[1:2,:]
end

function loss_n_ode()
    pred = predict_n_ode()
    loss = sum(abs2,ode_data .- pred)
    loss
end

loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE

cba = function (;doplot=false) #callback function to observe training
    pred = predict_n_ode()
    display(sum(abs2,ode_data .- pred))
    # plot current prediction against data
    pl = scatter(t,ode_data[1,:],label="data")
    scatter!(pl,t,pred[1,:],label="prediction")
    display(plot(pl))
    return false
end
cba()

ps = Flux.params(p)
data = Iterators.repeated((), 200)
Flux.train!(loss_n_ode, ps, data, ADAM(0.05), cb = cba)
```
@frankschae can you take a look at this one?
This error basically occurs whenever a `DiscreteCallback` is initialized with the `functioncalling_initialize` routine, as is the case, e.g., for the `FunctionCallingCallback`s.
I narrowed it down: `functioncalling_initialize(cb, u, t, integrator)` (Code) is called by default as the initialization function for discrete callbacks (or maybe only for the `FunctionCallingCallback`s?). This works fine as long as `cb.affect!` is a `FunctionCallingAffect`, but as soon as a `TrackedAffect` is used, the `cb` struct looks different: there is an additional layer of `affect!`. For example:
`cb.affect!.funciter` for a `FunctionCallingAffect` becomes `cb.affect!.affect!.funciter` for a `TrackedAffect`.
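To make the extra layer concrete, here is a minimal, self-contained sketch. The types `MockFunctionCallingAffect` and `MockTrackedAffect` and the helper `naive_funciter` are hypothetical stand-ins, not the real DiffEqCallbacks/SciMLSensitivity code; they only mimic the field nesting described above. A routine that hard-codes the `affect.funciter` path fails once it receives the wrapper.

```julia
# Hypothetical stand-in for FunctionCallingAffect:
# the iterator is stored directly in a `funciter` field.
struct MockFunctionCallingAffect{F}
    funciter::F
end

# Hypothetical stand-in for TrackedAffect:
# the original affect is wrapped in an extra `affect!` field.
struct MockTrackedAffect{A}
    affect!::A
end

# Mimics an initializer that assumes the unwrapped layout.
naive_funciter(affect) = affect.funciter

inner = MockFunctionCallingAffect([1.0, 2.0])
wrapped = MockTrackedAffect(inner)

println(naive_funciter(inner))            # works: [1.0, 2.0]
println(naive_funciter(wrapped.affect!))  # works: the field sits one level deeper

err = try
    naive_funciter(wrapped)  # throws: the wrapper itself has no `funciter` field
    nothing
catch e
    e
end
println(sprint(showerror, err))
```

The printed message is a "has no field funciter" error, analogous to the one in the issue title; its exact wording depends on the Julia version.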
So my proposal to fix this is:
- split the initialization routine `functioncalling_initialize` in DiffEqCallbacks into two functions, so that we can use multiple dispatch on the type of `cb.affect!` (see DiffEqCallbacks PR);
- extend the resulting callback-initialization function so that it correctly handles the `TrackedAffect` introduced in SciMLSensitivity (see SciMLSensitivity PR).
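The dispatch idea in this proposal can be sketched with the same kind of hypothetical stand-in types. These are not the real library structs, and the names `get_funciter` and `mock_initialize` are made up for illustration; this is not the code from the linked PRs. The generic initializer only calls a small accessor, and one extra method for the wrapper type strips the additional `affect!` layer.

```julia
# Hypothetical stand-ins mimicking the field layout described above;
# these are NOT the real DiffEqCallbacks / SciMLSensitivity types.
struct MockFunctionCallingAffect{F}
    funciter::F
end

struct MockTrackedAffect{A}
    affect!::A
end

# Part 1 (callbacks-package side): a hypothetical accessor with a
# method for the plain affect type.
get_funciter(affect::MockFunctionCallingAffect) = affect.funciter

# Part 2 (sensitivity-package side): one extra method that strips the
# tracking wrapper and recurses into the inner affect.
get_funciter(affect::MockTrackedAffect) = get_funciter(affect.affect!)

# Hypothetical initializer: it goes through the accessor instead of a
# hard-coded field path, so both layouts are handled.
function mock_initialize(affect)
    funciter = get_funciter(affect)
    # The real routine would do its initialization work here;
    # this sketch just returns what it found.
    return funciter
end

# Usage
inner = MockFunctionCallingAffect(0.0:0.5:1.0)
wrapped = MockTrackedAffect(inner)
println(mock_initialize(inner) == mock_initialize(wrapped))  # true
```

Because the unwrapping method recurses, a doubly wrapped affect would also resolve in this sketch. That is a property of the sketch, not a claim about the actual fix.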
What do you think @frankschae? Best regards, Tobi
Forgot to say: all libraries are on their current releases. The example at the top is outdated (and was far too specific), but I hope the error itself is understandable. Let me know if I can help.
This is closed together with #801