universal_differential_equations
SEIR Example
Translation of the SEIR example, based on Lotka-Volterra 1:
Hiya, ok, here's the first...
- I can't get it to train only on dE, dI, dR the way the old example did, though it seems like the way I did it should work (also, back when this did work on the old example, it seemed to be linearising the other equations in order to get there, so...). See the restricted-loss sketch after this list.
- It doesn't predict correctly for either the UDE or the NODE. The UDE prediction is linear, the same problem I was having before; the NODE is non-linear but incorrect. (Hopefully it's something simple with the ADAM and BFGS setup, because I don't fully understand that part of the code?)
- I don't know the correct code to extrapolate after training (lines 155-158 and 248-251); see the remake sketch after this list.
- I haven't done the SINDy part of the code yet, since the approximation doesn't work.
- Why can't I drop the .jl file in here?
- The savefigs don't work in the Lotka-Volterra examples; they can be changed to this format.
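For the first and third bullets, here's roughly what I think should work, though I haven't verified it (the names match my script below):

# Train only on dE, dI, dR by restricting the L2 loss to rows 2:4 of the data
function loss_EIR(θ)
    X̂ = predict(θ)
    sum(abs2, noisy_data[2:4, :] .- X̂[2:4, :])
end

# Extrapolate after training: remake the problem on the longer timespan and
# solve with the trained parameters (no adjoint/sensealg needed once training is done)
prob_extrapolate = remake(prob_node, tspan = tspan2)
X̂_long = Array(solve(prob_extrapolate, Vern7(), u0 = noisy_data[:,1], p = p_trained,
                     saveat = solution_extrapolate.t, abstol = 1e-6, reltol = 1e-6))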
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()
# Single experiment, move to ensemble further on
# Some good parameter values are stored as comments right now
# because this is really good practice
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using SciMLSensitivity
using Random
using Optimization, OptimizationFlux, OptimizationOptimJL #OptimizationFlux for ADAM and OptimizationOptimJL for BFGS
using Lux
using Statistics
using Plots
gr()
# Old-interface packages, replaced by the imports above:
#using DiffEqSensitivity, Optim
#using DiffEqFlux, Flux
function corona!(du,u,p,t)
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p
    dS = -β0*S*F/N - β(t,β0,D,N,κ,α)*S*I/N - μ*S # susceptible
    dE = β0*S*F/N + β(t,β0,D,N,κ,α)*S*I/N - (σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # +cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end
β(t,β0,D,N,κ,α) = β0*(1-α)*(1-D/N)^κ
S0 = 14e6
u0 = [0.9*S0, 0.0, 0.0, 0.0, S0, 0.0, 0.0]
p_ = [10.0, 0.5944, 0.4239, 1117.3, 0.02, 1/3, 1/5,0.2, 1/11.2]
R0 = p_[2]/p_[7]*p_[6]/(p_[6]+p_[5]) # basic reproduction number: R0 = β0/γ * σ/(σ+μ)
tspan = (0.0, 21.0)
prob = ODEProblem(corona!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
t = solution.t
#[2:4] are Exposed, Infected, Removed
X = Array(solution[2:4,:])'
plot(X)
#Extrapolate to a longer timespan
tspan2 = (0.0,60.0)
prob = ODEProblem(corona!, u0, tspan2, p_)
solution_extrapolate = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
extrapolate = Array(solution_extrapolate[2:4,:])'
plot(extrapolate)
# Ideal data
tsdata = Array(solution)
# Add noise to the data
noisy_data = tsdata + Float32(1e-5)*randn(eltype(tsdata), size(tsdata))
# You can see that the noise looks random
plot(abs.(tsdata-noisy_data)')
### Neural ODE
# The network predicts the right-hand sides of the unknown equations
rng = Random.default_rng()
Random.seed!(111)
#7 inputs for 7 equations, 5 outputs because we know 2 equations already
U = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 5))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
function coronaNODE(du,u,p,t,p_true)
    û = U(u, p, st)[1] # network prediction
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_true # full unpack; `μ,σ = p_` would have grabbed F and β0 instead
    dS = û[1]
    dE = û[2]
    dI = û[3]
    dR = û[4]
    dN = -μ*N # total population
    dD = û[5]
    dC = σ*E # cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end
# Closure with the known parameters
NODE_dynamics!(du,u,p,t) = coronaNODE(du,u,p,t,p_)
# Define the problem
prob_node = ODEProblem(NODE_dynamics!, u0, tspan, p)
## Function to train the network
# Define a predictor
function predict(θ, X = noisy_data[:,1], T = t)
    Array(solve(prob_node, Vern7(), u0 = X, p = θ,
                saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())))
end
# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, noisy_data .- X̂)
end
# Container to track the losses
losses = Float32[]
callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
## Training
# First train with ADAM, which is robust far from the optimum, to move the parameters
# into a favourable starting position for BFGS, which converges quickly near the optimum
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "pl_lossesNODE.png")
# Rename the best candidate
p_trained = res2.minimizer
## Analysis of the trained network
# Plot the data and the approximation
# Make the prediction to match solution.t
X̂ = predict(p_trained, noisy_data[:,1], t)
# Prediction trained on noisy data vs real solution
pl_trajectory = plot(t, transpose(X̂[2:4,:]), xlabel = "t", ylabel = "E(t), I(t), R(t)", color = :red, label = ["NODE Approximation" nothing])
scatter!(solution.t, transpose(noisy_data[2:4,:]), color = :black, label = ["Measurements" nothing])
savefig(pl_trajectory, "plots_trajectory_reconstructionNODE.png")
#Extrapolate the solution to match tspan2
ExtrapolateX̂ = predict(p_trained, noisy_data[:,1], solution_extrapolate.t)
extrapolate_trajectory = plot(solution_extrapolate.t, transpose(ExtrapolateX̂[2:4,:]), xlabel = "t", ylabel = "E(t), I(t), R(t)", color = :red, label = ["NODE Approximation" nothing])
scatter!(solution_extrapolate.t, transpose(solution_extrapolate[2:4,:]), color = :black, label = ["True Solution" nothing])
savefig(extrapolate_trajectory, "ExtrapolateNODE.png")
### Universal ODE
## Predicts the missing interaction term in the equations
rng = Random.default_rng()
Random.seed!(222)
#7 inputs for 7 equations, 1 output for 1 missing part of the equation
U = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 1))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
function coronaUDE(du,u,p,t,p_true)
    û = U(u, p, st)[1] # network prediction of the unknown interaction term
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_true # unpack the known parameters passed in, not the global p_
    dS = -β0*S*F/N - û[1] - μ*S # susceptible
    dE = β0*S*F/N + û[1] - (σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end
# Closure with the known parameters
UDE_dynamics!(du,u,p,t) = coronaUDE(du,u,p,t,p_)
# Define the problem
prob_ude = ODEProblem(UDE_dynamics!, u0, tspan, p)
## Function to train the network
# Define a predictor
function predict(θ, X = noisy_data[:,1], T = t)
    Array(solve(prob_ude, Vern7(), u0 = X, p = θ,
                saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())))
end
# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, noisy_data .- X̂)
end
# Container to track the losses
losses = Float32[]
callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
## Training
# First train with ADAM, which is robust far from the optimum, to move the parameters
# into a favourable starting position for BFGS, which converges quickly near the optimum
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1UDE = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob2 = Optimization.OptimizationProblem(optf, res1UDE.minimizer) # continue from this UDE run's ADAM result, not res1 from the NODE section
res2UDE = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "plot_lossesUDE.png")
# Rename the best candidate
p_trained = res2UDE.minimizer
## Analysis of the trained network
# Plot the data and the approximation
X̂ = predict(p_trained, noisy_data[:,1], t)
# Trained on noisy data vs real solution
pl_trajectory = plot(t, transpose(X̂[2:4,:]), xlabel = "t", ylabel = "E(t), I(t), R(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(solution[2:4,:]), color = :black, label = ["True Solution" nothing])
savefig(pl_trajectory, "plot_trajectory_reconstructionUDE.png")
# Extrapolate out
ExtrapolateX̂ = predict(p_trained, noisy_data[:,1], solution_extrapolate.t)
extrapolate_trajectory = plot(solution_extrapolate.t, transpose(ExtrapolateX̂[2:4,:]), xlabel = "t", ylabel = "E(t), I(t), R(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution_extrapolate.t, transpose(solution_extrapolate[2:4,:]), color = :black, label = ["True Solution" nothing])
savefig(extrapolate_trajectory, "ExtrapolateUDE.png")
P.S. Just realised it was unclear which model is the universal ODE and which is the neural ODE; edited.
@rajdandekar
@ccrnn I have also been working on translating the UDE codes into the SciML Sensitivity + Lux interface. Here are the key points, based on your prior comment:
(a) The code I have provided below mimics the original code closely.
(b) The plots for both prediction and estimation match the original plots in the paper.
(c) I have not yet done the SINDy part, but will implement it in the coming days (a rough sketch is marked in the code below).
Can you have a look at the code below and compare it with yours? It may also improve some of the results you are seeing on your end:
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using Lux,Optimization, OptimizationOptimJL, DiffEqFlux, Flux
using Plots
using Random
rng = Random.default_rng()
function corona!(du,u,p,t)
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p
    dS = -β0*S*F/N - β(t,β0,D,N,κ,α)*S*I/N - μ*S # susceptible
    dE = β0*S*F/N + β(t,β0,D,N,κ,α)*S*I/N - (σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # +cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end
β(t,β0,D,N,κ,α) = β0*(1-α)*(1-D/N)^κ
S0 = 14e6
u0 = [0.9*S0, 0.0, 0.0, 0.0, S0, 0.0, 0.0]
p_ = [10.0, 0.5944, 0.4239, 1117.3, 0.02, 1/3, 1/5,0.2, 1/11.2]
R0 = p_[2]/p_[7]*p_[6]/(p_[6]+p_[5])
tspan = (0.0, 21.0)
prob = ODEProblem(corona!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
tspan2 = (0.0,60.0)
prob = ODEProblem(corona!, u0, tspan2, p_)
solution_extrapolate = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
# Ideal data
tsdata = Array(solution)
# Add noise to the data
noisy_data = tsdata + Float32(1e-5)*randn(eltype(tsdata), size(tsdata))
plot(abs.(tsdata-noisy_data)')
### Neural ODE
ann_node = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 7))
p1, st1 = Lux.setup(rng, ann_node)
p = Lux.ComponentArray(p1)
function dudt_node(du, u, p, t)
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_
    du[1] = dS = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][1]
    du[2] = dE = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][2]
    du[3] = dI = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][3]
    du[4] = dR = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][4]
    du[5] = dN = -μ*N # total population (u[5] is N, so du[5] must be dN)
    du[6] = dD = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][5]
    du[7] = dC = σ*E # +cumulative cases
end
prob_node = ODEProblem{true}(dudt_node, u0, tspan)
function predict(θ)
    Array(solve(prob_node, Tsit5(), p = θ, saveat = 1, abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end
# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, noisy_data[2:4,:] .- pred[2:4,:]) # + 1e-5*sum(sum.(abs, params(ann)))
end
loss(p)
iter = 0
function callback(θ, l)
    global iter
    iter += 1
    if iter % 10 == 0
        println(l)
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1 = Optimization.solve(optprob, ADAM(0.0001), callback = callback, maxiters = 1500)
optprob2 = remake(optprob,u0 = res1.u)
res2 = Optimization.solve(optprob2,Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
maxiters = 10000)
data_pred = predict(res2.u)
scatter(solution, vars=[2,3,4], label=["True Exposed" "True Infected" "True Recovered"])
plot!(data_pred[2,:], label=["Estimated Exposed"])
plot!(data_pred[3,:], label=["Estimated Infected" ])
plot!(data_pred[4,:], label=["Estimated Recovered"])
# Plot the losses
# TO DO: plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")
# Extrapolate out
prob_node_extrapolate = ODEProblem{true}(dudt_node, u0, tspan2)
_sol_node = Array(solve(prob_node_extrapolate, Tsit5(),p = res2.u, saveat = 1,abstol=1e-12, reltol=1e-12,
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
p_node = scatter(solution_extrapolate, vars=[2,3,4], legend = :topleft, label=["True Exposed" "True Infected" "True Recovered"], title="Neural ODE Extrapolation")
plot!(p_node,_sol_node[2,:], lw = 5, label=["Estimated Exposed"])
plot!(p_node,_sol_node[3,:], lw = 5, label=["Estimated Infected" ])
plot!(p_node,_sol_node[4,:], lw = 5, label=["Estimated Recovered"])
plot!(p_node,[20.99,21.01],[0.0,maximum(hcat(Array(solution_extrapolate[2:4,:]),Array(_sol_node[2:4,:])))],lw=5,color=:black,label="Training Data End")
savefig("neuralode_extrapolation.png")
savefig("neuralode_extrapolation.pdf")
### Universal ODE Part 1
ann = Lux.Chain(Lux.Dense(3, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 1))
p1, st1 = Lux.setup(rng, ann)
p = Lux.ComponentArray(p1)
function dudt_(du, u, p, t)
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_
    z = ann([S/N,I,D/N], p, st1)[1][1] # scalar network output
    du[1] = dS = -β0*S*F/N - z - μ*S # susceptible
    du[2] = dE = β0*S*F/N + z - (σ+μ)*E # exposed
    du[3] = dI = σ*E - (γ+μ)*I # infected
    du[4] = dR = γ*I - μ*R # removed (recovered + dead)
    du[5] = dN = -μ*N # total population
    du[6] = dD = d*γ*I - λ*D # severe, critical cases, and deaths
    du[7] = dC = σ*E # +cumulative cases
end
prob_nn = ODEProblem{true}(dudt_,u0, tspan)
function predict(θ)
    Array(solve(prob_nn, Tsit5(), p = θ, saveat = solution.t, abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end
# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, noisy_data[2:4,:] .- pred[2:4,:]) # + 1e-5*sum(sum.(abs, params(ann)))
end
loss(p)
iter = 0
function callback(θ, l)
    global iter
    iter += 1
    if iter % 50 == 0
        println(l)
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1 = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = 500)
optprob2 = remake(optprob,u0 = res1.u)
res2 = Optimization.solve(optprob2,Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
maxiters = 550)
uode_sol = predict(res2.u)
scatter(solution, vars=[2,3,4], label=["True Exposed" "True Infected" "True Recovered"])
plot!(uode_sol[2,:], label=["Estimated Exposed"])
plot!(uode_sol[3,:], label=["Estimated Infected" ])
plot!(uode_sol[4,:], label=["Estimated Recovered"])
# Plot the losses
#TO DO: plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")
# Collect the state trajectory and the derivatives
#X = noisy_data
# Ideal derivatives
#DX = Array(solution(solution.t, Val{1}))
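# --- Rough, untested sketch of the SINDy recovery step (to be implemented) ---
# Assumes DataDrivenDiffEq's DirectDataDrivenProblem / STLSQ interface; on
# DataDrivenDiffEq v1.x, STLSQ lives in the separate DataDrivenSparse package,
# so check against your installed version. Evaluate the trained network along
# the fitted trajectory, then sparse-regress its output against a polynomial
# basis in its three inputs [S/N, I, D/N].
nn_input = vcat((uode_sol[1,:] ./ uode_sol[5,:])', uode_sol[3,:]', (uode_sol[6,:] ./ uode_sol[5,:])')
Ŷ = ann(nn_input, res2.u, st1)[1] # 1 x n matrix of network outputs
@variables v[1:3]
sindy_basis = Basis(polynomial_basis(v, 2), v)
sindy_prob = DirectDataDrivenProblem(nn_input, Ŷ)
sindy_res = solve(sindy_prob, sindy_basis, STLSQ(1e-1))
println(sindy_res) # inspect the recovered sparse equations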
# Extrapolate out
prob_nn2 = ODEProblem{true}(dudt_, u0, tspan2)
_sol_uode = Array(solve(prob_nn2, Tsit5(),p = res2.u, saveat = 1,abstol=1e-12, reltol=1e-12,
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
p_uode = scatter(solution_extrapolate, vars=[2,3,4], legend = :topleft, label=["True Exposed" "True Infected" "True Recovered"], title="Universal ODE Extrapolation")
plot!(p_uode,_sol_uode[2,:], lw = 5, label=["Estimated Exposed"])
plot!(p_uode,_sol_uode[3,:], lw = 5, label=["Estimated Infected" ])
plot!(p_uode,_sol_uode[4,:], lw = 5, label=["Estimated Recovered"])
plot!(p_uode,[20.99,21.01],[0.0,maximum(hcat(Array(solution_extrapolate[2:4,:]),Array(_sol_uode[2:4,:])))],lw=5,color=:black,label="Training Data End")
savefig("universalode_extrapolation.png")
savefig("universalode_extrapolation.pdf")
Thanks for this! How did you find the right form for the [1][1], [1][2], etc.? I was trying to find that, and the ComponentArray too. What exactly does the first [1] do?
Not being able to predict for [2:4] turned out to be something weird with having u0 in the predict function.
I am still seeing linear approximations for the first example, and incorrect non-linear approximations for the second, with your code too, though?
@RajDandekar
@ccrnn: The first [1] extracts the network output, the vector of 5 elements, from the tuple that the Lux model returns. Then we access each element separately through one more level of indexing.
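A minimal standalone sketch of that indexing (assuming the usual Lux calling convention; nothing here is specific to the SEIR code):

using Lux, Random
model = Lux.Chain(Lux.Dense(7, 64, tanh), Lux.Dense(64, 5))
ps, st = Lux.setup(Random.default_rng(), model)
ps = Lux.ComponentArray(ps)                         # flat, optimizer-friendly view of the parameters
out, st_new = model(rand(Float32, 7), ps, st)       # a Lux model returns the tuple (output, state)
first_entry = model(rand(Float32, 7), ps, st)[1][1] # first [1] selects the output vector, second [1] its first entry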
Regarding your second question: even in Chris's original code, the Neural ODE and the UDE extrapolations are not good.
See this: https://github.com/ChrisRackauckas/universal_differential_equations/blob/master/SEIR_exposure/neuralode_extrapolation.png
and this: https://github.com/ChrisRackauckas/universal_differential_equations/blob/master/SEIR_exposure/universalode_extrapolation.png
For now, it's good that we can match those results with SciML Sensitivity.
We can spend some time later optimizing the hyperparameters etc. to get better results.