DifferentialEquations.jl
DifferentialEquations.jl copied to clipboard
Thread-savety for PeriodicCallback
I'd like to apply a controller periodically in an ensemble simulation. The controller should act based on the current state of the integrator and change a parameter of the DE. However, PeriodicCallback doesn't seem to be thread save, despite using safetycopy = true.
using Flux, DiffEqFlux, DiffEqSensitivity
using DifferentialEquations
using Plots
function lotka_volterra!(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = dx = α*x - β*x*y
du[2] = dy = -δ*y + γ*x*y
end
# Initial condition
u0 = [1.0, 1.0]
# Simulation interval and intermediary points
tspan = (0.0, 10.0)
dt = 0.1
nn = Chain(Dense(2, 1, tanh))
p_nn, re = Flux.destructure(nn)
pars = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem{true}(lotka_volterra!, u0, tspan, pars)
sol = solve(prob, Tsit5(), adaptive=false, dt=0.05)
function loss(p)
function affect!(integrator)
integrator.p[2]=(re(p)(integrator.u))[1]
end
cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))
sol = solve(prob, Tsit5(), saveat = dt, callback=cb, adaptive=false, dt=0.05)
loss = sum(abs2, sol.-1)
return loss, sol
end
l1, sol1 = loss(p_nn)
function ensembleloss(p)
function affect!(integrator)
integrator.p[2]=(re(p)(integrator.u))[1]
end
cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))
ensembleprob = EnsembleProblem(prob,
safetycopy = true
)
sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(), saveat = dt,
callback=cb, adaptive=false, dt=0.05, trajectories = 5)
loss = sum(abs2, sol.-1)/5
return loss, sol
end
Threads.nthreads() = 1
l2, sol2 = ensembleloss(p_nn)
with Threads.nthreads() = 1, the loss values l1 and l2 are the same (up to 1e-13).
with Threads.nthreads() = 4, the loss values and the associated trajectories can be different.
Sometimes (?), I get an error:
ERROR: TaskFailedException:
Tried to add a tstop that is behind the current time. This is strictly forbidden
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] add_tstop! at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_interface.jl:96 [inlined]
[3] (::DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}})(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/DiffEqCallbacks/b4ahb/src/iterative_and_periodic.jl:84
[4] apply_discrete_callback! at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/callbacks.jl:830 [inlined]
[5] handle_callbacks!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_utils.jl:259
[6] _loopfooter!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_utils.jl:220
[7] loopfooter! at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/integrators/integrator_utils.jl:166 [inlined]
[8] solve!(::OrdinaryDiffEq.ODEIntegrator{Tsit5,true,Array{Float64,1},Nothing,Float64,Array{Float64,1},Float64,Float64,Float64,Array{Array{Float64,1},1},ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Tuple{},Float64,Tuple{}},Array{Float64,1},Float64,Nothing,OrdinaryDiffEq.DefaultInit}) at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/solve.jl:429
[9] #__solve#391 at /Users/frank/.julia/packages/OrdinaryDiffEq/VPJBD/src/solve.jl:5 [inlined]
[10] solve_call(::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, ::Tsit5; merge_callbacks::Bool, kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:92
[11] #solve_up#461 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:114 [inlined]
[12] #solve#460 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:102 [inlined]
[13] batch_func(::Int64, ::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:146
[14] #363 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:180 [inlined]
[15] iterate at ./generator.jl:47 [inlined]
[16] _collect(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},DiffEqBase.var"#363#364"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:699
[17] collect_similar at ./array.jl:628 [inlined]
[18] map at ./abstractarray.jl:2162 [inlined]
[19] solve_batch(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5, ::EnsembleSerial, ::UnitRange{Int64}, ::Int64; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:179
[20] (::DiffEqBase.var"#367#369"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5,UnitRange{Int64},Int64,Int64})(::Int64) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:206
[21] macro expansion at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:214 [inlined]
[22] (::DiffEqBase.var"#509#threadsfor_fun#370"{DiffEqBase.var"#367#369"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5,UnitRange{Int64},Int64,Int64},Tuple{UnitRange{Int64}},Array{Any,1},UnitRange{Int64}})(::Bool) at ./threadingconstructs.jl:81
[23] (::DiffEqBase.var"#509#threadsfor_fun#370"{DiffEqBase.var"#367#369"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Tsit5,UnitRange{Int64},Int64,Int64},Tuple{UnitRange{Int64}},Array{Any,1},UnitRange{Int64}})() at ./threadingconstructs.jl:48
Stacktrace:
[1] wait at ./task.jl:267 [inlined]
[2] threading_run(::Function) at ./threadingconstructs.jl:34
[3] macro expansion at ./threadingconstructs.jl:93 [inlined]
[4] tmap(::Function, ::UnitRange{Int64}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:213
[5] solve_batch(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5, ::EnsembleThreads, ::UnitRange{Int64}, ::Int64; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:200
[6] batch_function at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
[7] macro expansion at ./timing.jl:233 [inlined]
[8] __solve(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5, ::EnsembleThreads; trajectories::Int64, batch_size::Int64, pmap_batch_size::Int64, kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:112
[9] __solve(::EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Tsit5; kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{6,Symbol},NamedTuple{(:ensemblealg, :saveat, :callback, :adaptive, :dt, :trajectories),Tuple{EnsembleThreads,Float64,DiscreteCallback{DiffEqCallbacks.var"#44#49"{Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}},DiffEqCallbacks.var"#46#51"{Bool,DiffEqCallbacks.var"#48#53"{Bool},Float64,Base.RefValue{Float64},Base.RefValue{Int64},DiffEqCallbacks.var"#45#50"{var"#affect!#18"{Array{Float32,1}},Float64,Base.RefValue{Float64},Base.RefValue{Int64}}}},Bool,Float64,Int64}}}) at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/ensemble/basic_ensemble_solve.jl:87
[10] #solve#462 at /Users/frank/.julia/packages/DiffEqBase/V7P18/src/solve.jl:128 [inlined]
[11] ensembleloss(::Array{Float32,1}) at /Users/frank/switchdrive/Institution/stochastic_control/ODE_control/threadsafetytest.jl:59
[12] top-level scope at none:1
Updated code based on https://github.com/SciML/DifferentialEquations.jl/issues/646 (moving the callback to the problem type instead of the solve call). If I choose the number of trajectories numtraj >= Threads.nthreads(), some trajectories are different and if numtraj is increased to far, I obtain the error "Tried to add a tstop that is behind the current time.. ". ( ensembleloss(..) and ensembleloss2(..) show the same behaviour.)
# load packages
using Flux, DiffEqFlux, DiffEqSensitivity
using DifferentialEquations. ### version: DifferentialEquations v6.15.0
using Plots
using LinearAlgebra
using Test, Random
function lotka_volterra!(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = dx = α*x - β*x*y
du[2] = dy = -δ*y + γ*x*y
end
# Initial condition
u0 = [1.0, 1.0]
# Simulation interval and intermediary points
tspan = (0.0, 10.0)
dt = 0.1
Random.seed!(10)
nn = Chain(Dense(2, 1, relu))
p_nn, re = Flux.destructure(nn)
pars = [1.5, 1.0, 3.0, 1.0]
function affect!(integrator)
integrator.p[2]=(re(p_nn)(integrator.u))[1]
end
cb = PeriodicCallback(affect!, dt; initial_affect = true, save_positions=(false,false))
prob = ODEProblem{true}(lotka_volterra!, u0, tspan, pars, callback=cb)
sol = solve(prob, Tsit5(), adaptive=true, dt=0.001, saveat=dt)
plot(sol)
@show sum(abs2, sol.-1)
function loss(p; sensealg=ForwardDiffSensitivity())
function affect2!(integrator)
Ω = (re(p)(integrator.u))[1]
integrator.p[2] = Ω
end
cb2 = PeriodicCallback(affect2!,dt;initial_affect=true,save_positions=(false,false))
tmp_prob = remake(prob, callback=cb2)
sol = solve(tmp_prob, Tsit5(), sensealg=sensealg, saveat = dt, adaptive=true, dt=0.001)
loss = sum(abs2, sol.-1)
return loss, sol
end
l1, sol1 = loss(p_nn)
plot(sol1)
function ensembleloss(p; numtraj=5, sensealg=ForwardDiffSensitivity())
function affect3!(integrator)
integrator.p[2]=(re(p)(integrator.u))[1]
end
cb3 = PeriodicCallback(affect3!, dt; initial_affect = true, save_positions=(false,false))
function prob_func(prob,i,repeat)
remake(prob,callback = cb3)
end
ensembleprob = EnsembleProblem(prob,
prob_func = prob_func, safetycopy = true
)
sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(),
sensealg = sensealg,
saveat = dt,
adaptive=true, dt=0.001, trajectories = numtraj)
loss = sum(abs2, sol.-1)/numtraj
return loss, sol
end
l2, sol2 = ensembleloss(p_nn, numtraj=2)
plot!(sol2)
@test isapprox(l1, l2, atol=1e-10)
function ensembleloss2(p; numtraj=5, sensealg=ForwardDiffSensitivity())
function affect3!(integrator)
integrator.p[2]=(re(p)(integrator.u))[1]
end
cb3 = PeriodicCallback(affect3!, dt; initial_affect = true, save_positions=(false,false))
tmp_prob = remake(prob, callback = cb3)
ensembleprob = EnsembleProblem(tmp_prob)
sol = solve(ensembleprob, Tsit5(), ensemblealg=EnsembleThreads(),
sensealg = sensealg,
saveat = dt,
adaptive=true, dt=0.001, trajectories = numtraj)
loss = sum(abs2, sol.-1)/numtraj
return loss, sol
end
l3, sol3 = ensembleloss2(p_nn, numtraj=10)
plot(sol3)
@test isapprox(l1, l3, atol=1e-10)
@show Threads.nthreads() ### 4