NeuralPDE.jl
NeuralPDE.jl copied to clipboard
Vector output for PINO ODE
Implement vector-valued (multi-component) output for the PINO ODE solver (`PINOODE`), so that systems of ODEs can be solved.
draft test case:
#vector outputs
@testset "Example ode system: du1 = cos(p * t); du2 = sin(p * t)" begin
    # Right-hand side of the two-component system. Each component is driven by its
    # own parameter: p[1] for the cosine term and p[2] for the sine term.
    equation = (u, p, t) -> [cos(p[1] * t), sin(p[2] * t)]
    tspan = (0.0f0, 1.0f0)
    # The state has two components, so the initial condition must be a 2-vector
    # (the analytic solution below indexes u0[1] and u0[2]).
    u0 = [1.0f0, 1.0f0]
    prob = ODEProblem(equation, u0, tspan)

    # One DeepONet per output component. The branch net takes both parameters
    # (input_branch_size = 2); the trunk net takes the scalar time.
    input_branch_size = 2
    deeponet1 = LuxNeuralOperators.DeepONet(
        Chain(
            Dense(input_branch_size => 10, Lux.tanh_fast),
            Dense(10 => 10, Lux.tanh_fast),
            Dense(10 => 10)),
        Chain(
            Dense(1 => 10, Lux.tanh_fast),
            Dense(10 => 10, Lux.tanh_fast),
            Dense(10 => 10, Lux.tanh_fast)))
    deeponet2 = LuxNeuralOperators.DeepONet(
        Chain(
            Dense(input_branch_size => 10, Lux.tanh_fast),
            Dense(10 => 10, Lux.tanh_fast),
            Dense(10 => 10)),
        Chain(
            Dense(1 => 10, Lux.tanh_fast),
            Dense(10 => 10, Lux.tanh_fast),
            Dense(10 => 10, Lux.tanh_fast)))
    # Use both operators; the draft listed deeponet1 twice and left deeponet2 unused.
    deeponets = [deeponet1, deeponet2]

    # Parameter intervals, one (lower, upper) tuple per parameter. Float32(pi) keeps
    # every tuple a concrete Tuple{Float32, Float32}.
    bounds = [(1.0f0, Float32(pi)), (2.0f0, 3.0f0)]
    number_of_parameters = 50
    strategy = StochasticTraining(40)
    opt = OptimizationOptimisers.Adam(0.03)
    alg = PINOODE(deeponets, opt, bounds, number_of_parameters; strategy = strategy)
    sol = solve(prob, alg, verbose = true, maxiters = 3000)

    # Build the evaluation grid:
    #   p — (length(bounds), number_of_parameters) matrix, one row per parameter,
    #       each row evenly spaced across its interval in `bounds`;
    #   t — (1, nt, 1) array of time points from tspan[1] to tspan[2] in steps of dt.
    function get_trainset(bounds, tspan, number_of_parameters, dt)
        p_ = [range(start = b[1], length = number_of_parameters, stop = b[2])
              for b in bounds]
        p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
        t_ = collect(tspan[1]:dt:tspan[2])
        t = collect(reshape(t_, 1, size(t_, 1), 1))
        return (p, t)
    end

    # Analytic solution for a single parameter pair `p` (2-vector) at scalar time `t`,
    # obtained by integrating the right-hand side from 0 to t:
    #   u1(t) = u0[1] + sin(p[1] * t) / p[1]
    #   u2(t) = u0[2] + (1 - cos(p[2] * t)) / p[2]
    # The `1 -` term is the integration constant that makes u2(0) == u0[2].
    ground_solution = (u0, p, t) -> [
        u0[1] + sin(p[1] * t) / p[1],
        u0[2] + (1 - cos(p[2] * t)) / p[2]
    ]

    # Evaluate the analytic solution for every parameter column of `p` and every
    # time point of `t` (t[j] linearly indexes the (1, nt, 1) array).
    # NOTE(review): the container layout produced here must match whatever
    # `sol.interp((p, t))` returns for vector outputs — confirm once implemented.
    function ground_solution_f(p, t)
        return reduce(hcat,
            [[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)]
             for i in axes(p, 2)])
    end

    # Check the trained operator on the training resolution and on a finer grid
    # that was not used during training.
    for (n_params, dt) in ((number_of_parameters, 0.025f0), (100, 0.01f0))
        p, t = get_trainset(bounds, tspan, n_params, dt)
        expected = ground_solution_f(p, t)
        predicted = sol.interp((p, t))
        @test expected ≈ predicted rtol=0.01
    end
end
References: https://github.com/LuxDL/LuxNeuralOperators.jl/issues/9 and https://github.com/SciML/NeuralPDE.jl/pull/806