universal_differential_equations
LV Scenario 2
Incorrect predictions, again....
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using Optimization, OptimizationFlux, OptimizationOptimJL # OptimizationFlux provides ADAM, OptimizationOptimJL provides BFGS
using Lux
using SciMLSensitivity
using Plots
gr()
using JLD2, FileIO
using Statistics
# Set a random seed for reproducible behaviour
using Random
Random.seed!(2345)
#### NOTE
# Since the release of DataDrivenDiffEq v0.6.0, which completely overhauled the optimizers,
# SR3 had been used here. STLSQ currently performs better, so the script has been switched to it.
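# Hedged sketch (not part of the original script; the sparse-regression call that
# consumes this optimizer only appears later in the full tutorial). With the
# DataDrivenDiffEq v0.6-style API, the switch described in the note looks roughly like:
λ = Float32.(exp10.(-7:0.3:0))   # assumed example threshold grid
opt = STLSQ(λ)                   # sequentially thresholded least squares (current choice)
# opt = SR3(Float32(1e-2), Float32(1e-2))   # previously used optimizer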
# Create a name for saving (basically a prefix)
svname = "Scenario_2_"
## Data generation
function lotka!(du, u, p, t)
α, β, γ, δ = p
du[1] = α*u[1] - β*u[2]*u[1]
du[2] = γ*u[1]*u[2] - δ*u[2]
end
# Define the experimental parameter
tspan = (0.0f0,6.0f0)
u0 = Float32[0.44249296,4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-6, reltol=1e-6, saveat = 0.1)
scatter(solution, alpha = 0.25)
plot!(solution, alpha = 0.5)
# Ideal data
X = Array(solution)
t = solution.t
# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = Float32(1e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))
# Subsample the data in y
# We assume we have only 5 measurements in y, evenly distributed
ty = collect(t[1]:Float32(6/5):t[end])
# Create datasets for the different measurements
round(Int64, mean(diff(ty))/mean(diff(t)))
XS = zeros(eltype(X), length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # All x data
TS = zeros(eltype(t), length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # Time data
YS = zeros(eltype(X), length(ty)-1, 2) # Just two measurements in y
for i in 1:length(ty)-1
idxs = ty[i].<= t .<= ty[i+1]
XS[i, :] = Xₙ[1, idxs]
TS[i, :] = t[idxs]
YS[i, :] = [Xₙ[2, t .== ty[i]]; Xₙ[2, t .== ty[i+1]]]
end
XS
scatter!(t, transpose(Xₙ))
## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))
# Define the network 2->5->5->5->2
U = Lux.Chain(
Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,2)
)
rng = Random.default_rng()
p1, st = Lux.setup(rng, U)
# Initial guess for the unknown linear decay rate of the predator
parameter_array = Float64[0.5]
p = (layer_1 = p1, layer_2 = parameter_array)
p = Lux.ComponentArray(p)
# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
û = U(u, p.layer_1, st)[1] # Network prediction
du[1] = p_true[1]*u[1] + û[1]
# We assume a linear decay rate for the predator
du[2] = -p.layer_2[1]*u[2] + û[2]
end
p_true = 1.3
# Closure with the known parameter
nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_true)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)
## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
saveat = T,
abstol=1e-6, reltol=1e-6,
sensealg = ForwardDiffSensitivity()
))
end
# Multiple shooting like loss
function loss(θ)
# Start with a regularization on the network
l = convert(eltype(θ), 1e-3)*sum(abs2, θ[2:end]) / length(θ[2:end])
for i in 1:size(XS,1)
X̂ = predict(θ, [XS[i,1], YS[i,1]], TS[i, :])
# Full prediction in x
l += sum(abs2, XS[i,:] .- X̂[1,:])
# Add the boundary condition in y
l += abs(YS[i, 2] - X̂[2, end])
end
return l
end
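# (Not in the original script.) Quick sanity check: evaluate the loss once at the
# initial parameters before starting the optimization.
loss(p)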
# Container to track the losses
losses = Float32[]
# Callback to show the loss during training
callback(θ,l) = begin
push!(losses, l)
if length(losses)%50==0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
false
end
## Training
# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.01f0), callback=callback, maxiters = 300)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve partial fit of the data
optprob2 = remake(optprob,u0 = res1.u)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01f0), callback=callback, maxiters = 10000, g_tol = 1e-10)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Plot the losses
pl_losses = plot(1:300, losses[1:300], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(301:length(losses), losses[301:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "plot_losses.png")
# Rename the best candidate
p_trained = res2.u
## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(t, X[1,:], color = :black, label = "Measurements")
ymeasurements = unique!(vcat(YS...))
tmeasurements = unique!(vcat([[ts[1], ts[end]] for ts in eachrow(TS)]...))
scatter!(tmeasurements, ymeasurements, color = :black, label = nothing, legend = :topleft)
savefig(pl_trajectory, "plot_trajectory_reconstruction.png")
# Ideal unknown interaction terms of the predator-prey model
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
Ŷ = U(X̂, p_trained.layer_1, st)[1]
pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing], legend = :topleft)
savefig(pl_reconstruction, "plot_missingterm_reconstruction.png")
# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ-Ŷ)), yaxis = :log, xlabel = "t", ylabel = "L2-Error", color = :red, label = nothing)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2,1))
savefig(pl_missing, "plots_missingterm_reconstruction_and_error.pdf")
pl_overall = plot(pl_trajectory, pl_missing)
savefig(pl_overall, "plots_reconstruction.png")
@RajDandekar
I've adapted scenarios 1-3, so it should be working in #48
Thaaaanks! How did you find the right format for the setup for this part? Is there documentation on this somewhere or do you just know it?
# Merge the parameters
p = (;δ = rand(rng), ude = p_nn)
p = ComponentVector{Float64}(p)
# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
û = U(u, p.ude, st_nn)[1] # Network prediction
du[1] = p_true[1]*u[1] + û[1]
# We assume a linear decay rate for the predator
du[2] = -p.δ*u[2] + û[2]
end
I've recently switched to Lux and set up some models in the past. NamedTuples and ComponentVectors are really helpful in structuring the overall parameters, and both are usable for AD.
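For a concrete sketch of that pattern (hypothetical names and a toy Dense layer, not taken from the script above): a scalar parameter and the network's parameters go into one NamedTuple, which is then flattened into a ComponentVector so the optimizer and AD see a single flat vector while the fields stay addressable by name.
using Lux, Random, ComponentArrays

rng = Random.default_rng()
model = Lux.Dense(2, 2)
ps_nn, st_nn = Lux.setup(rng, model)   # network parameters and state

p = (; δ = 0.5, ude = ps_nn)           # NamedTuple: scalar plus nested network parameters
p = ComponentVector{Float64}(p)        # flat vector, still accessible as p.δ and p.ude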
Additionally, I found this tutorial to be most insightful. Since I do not need a specific state, I just drop the information.
A more general way would be something along the lines of
mutable struct LuxContainer
model
state
end
(c::LuxContainer)(x, p) = begin
out, state = c.model(x, p, c.state)
c.state = state
return out
end
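A hypothetical usage of that container (model size and names are made up for illustration) might look like:
using Lux, Random

rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dense(5, 2))
ps, st = Lux.setup(rng, model)

c = LuxContainer(model, st)
y = c(rand(Float32, 2), ps)   # returns only the output; the updated state is kept inside c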
Thanks, that's helpful. I guess I need to slow down and try to really work step by step through these tutorials instead of scanning for stuff that looks like it might fit. I feel like I'm missing the "why" behind what works and what doesn't...