NeuralPDE.jl icon indicating copy to clipboard operation
NeuralPDE.jl copied to clipboard

Vector output for PINO ODE

Open KirillZubov opened this issue 6 days ago • 0 comments

Implement vector-valued output for PINO ODE, so that systems of ODEs (more than one state component) are supported.

Draft test case:

# Vector outputs: PINO ODE applied to a two-component ODE system, using one DeepONet
# per state component.
@testset "Example ode system: du1 = cos(p * t); du2 = sin(p * t)" begin
    # Right-hand side of the system: p[1] drives the first component, p[2] the second.
    equation = (u, p, t) -> [cos(p[1] * t), sin(p[2] * t)]
    tspan = (0.0f0, 1.0f0)
    # One initial value per state component. The draft used a scalar here, which cannot
    # be indexed as u0[2] by the analytic solution below.
    # NOTE(review): the second initial value is assumed to equal the first — confirm.
    u0 = [1.0f0, 1.0f0]
    prob = ODEProblem(equation, u0, tspan)

    input_branch_size = 2

    # Builds one DeepONet: the first Chain takes the `input_branch_size` parameters, the
    # second Chain takes the scalar time input.
    function build_deeponet()
        return LuxNeuralOperators.DeepONet(
            Chain(
                Dense(input_branch_size => 10, Lux.tanh_fast),
                Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
            Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
                Dense(10 => 10, Lux.tanh_fast)))
    end
    deeponet1 = build_deeponet()
    deeponet2 = build_deeponet()

    # One operator per output component (the draft listed `deeponet1` twice and left
    # `deeponet2` unused).
    deeponets = [deeponet1, deeponet2]

    # Keep every bound Float32 so the sampled parameters match the Float32 problem data;
    # a bare `pi` would make the tuple eltype abstract and the range Float64.
    bounds = [(1.0f0, Float32(pi)), (2.0f0, 3.0f0)]
    number_of_parameters = 50
    strategy = StochasticTraining(40)
    opt = OptimizationOptimisers.Adam(0.03)
    alg = PINOODE(deeponets, opt, bounds, number_of_parameters; strategy = strategy)
    sol = solve(prob, alg, verbose = true, maxiters = 3000)

    # Returns `(p, t)`: `p` has one row per parameter (each row is `number_of_parameters`
    # evenly spaced values between that parameter's bounds) and `t` is the time grid
    # `tspan[1]:dt:tspan[2]` reshaped to size (1, length, 1).
    function get_trainset(bounds, tspan, number_of_parameters, dt)
        p_ = [range(start = b[1], length = number_of_parameters, stop = b[2])
              for b in bounds]
        p = vcat([collect(reshape(p_i, 1, length(p_i))) for p_i in p_]...)
        t_ = collect(tspan[1]:dt:tspan[2])
        t = collect(reshape(t_, 1, length(t_), 1))
        return (p, t)
    end

    # Analytic solution with the initial condition imposed at t = 0 (== tspan[1]):
    #   u1(t) = u0[1] + sin(p[1] * t) / p[1]
    #   u2(t) = u0[2] + (1 - cos(p[2] * t)) / p[2]
    # `p` is the full parameter vector of one sample, so each component has to pick its
    # own entry; the `1 -` term makes u2(0) == u0[2].
    ground_solution = (u0, p, t) -> [u0[1] + sin(p[1] * t) / p[1],
        u0[2] + (1 - cos(p[2] * t)) / p[2]]

    # Evaluates the analytic solution for every parameter column of `p` and every time
    # point of `t`.
    # NOTE(review): the result is a matrix whose entries are 2-element vectors; this
    # layout has to agree with whatever `sol.interp` returns for vector outputs — confirm
    # once the feature is implemented.
    function ground_solution_f(p, t)
        return reduce(hcat,
            [[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
    end

    # Same number of parameter samples as used for training.
    (p, t) = get_trainset(bounds, tspan, number_of_parameters, 0.025f0)
    ground_solution_ = ground_solution_f(p, t)
    predict = sol.interp((p, t))
    @test ground_solution_≈predict rtol=0.01

    # Finer grid in both parameters and time.
    p, t = get_trainset(bounds, tspan, 100, 0.01f0)
    ground_solution_ = ground_solution_f(p, t)
    predict = sol.interp((p, t))
    @test ground_solution_≈predict rtol=0.01
end

References: https://github.com/LuxDL/LuxNeuralOperators.jl/issues/9 and https://github.com/SciML/NeuralPDE.jl/pull/806

KirillZubov avatar Jul 02 '24 12:07 KirillZubov