
NN ODE loss gets squared

milzj opened this issue on Sep 20, 2022 · 1 comment

I think the objective/loss functions that ode_solve.jl generates for GridTraining and StochasticTraining get squared before they are passed to OptimizationFunction; that is, $(\sum_i \ell_i^2)^2$ is passed to OptimizationFunction instead of $\sum_i \ell_i^2$. This comes from applying sum(abs2, ...) to a loss that is already a scalar sum of squares, e.g. in ode_solve.jl#L264. I am unsure whether this is intended.
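The mechanism is easy to demonstrate in isolation. This is a standalone sketch, not NeuralPDE's code; inner, outer, and res are made-up names. The point is that for a scalar x, sum(abs2, x) returns x^2, so wrapping a loss that already returns a sum of squares in sum(abs2, ...) squares it a second time:

# Standalone sketch (not NeuralPDE's code): `inner` plays the role of the
# generated loss, which already returns the scalar sum of squares ∑ᵢ ℓᵢ².
inner(residuals) = sum(abs2, residuals)        # ∑ᵢ ℓᵢ²
# Wrapping that scalar in sum(abs2, ...) squares it again, since for a
# scalar x, sum(abs2, x) == x^2.
outer(residuals) = sum(abs2, inner(residuals)) # (∑ᵢ ℓᵢ²)²

res = [0.1, -0.2, 0.3]
inner(res)  # 0.14
outer(res)  # 0.0196 == inner(res)^2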

Here is some code reproducing the behavior.

using Test
import Random, NeuralPDE
import Lux, OptimizationOptimisers

rhs(u, p, t) = cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
dt = 1 / 20f0
prob = NeuralPDE.ODEProblem(rhs, u0, tspan)
chain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))

opt = OptimizationOptimisers.Adam(0.01)
θ, st = Lux.setup(Random.seed!(1234), chain)
init_params = NeuralPDE.ComponentArrays.ComponentArray(θ)
strategy = NeuralPDE.GridTraining(dt)
alg = NeuralPDE.NNODE(chain, opt, init_params; strategy = strategy, batch = true)

# Perform one iteration to get access to sol.k.prob
sol = NeuralPDE.solve(prob, alg, maxiters = 1);

# Objective function evaluated at the initial parameters
loss_squared = sol.k.prob.f(init_params, 0)

ts = tspan[1]:(strategy.dx):tspan[2]
phi = NeuralPDE.ODEPhi(chain, tspan[1], u0, st)

autodiff = false
p = NeuralPDE.SciMLBase.NullParameters()

# Recompute the mean squared residual by hand
out = phi(ts, init_params)
fs = reduce(hcat, [rhs(out[i], p, ts[i]) for i in 1:size(out, 2)])
dxdtguess = Array(NeuralPDE.ode_dfdx(phi, ts, init_params, autodiff))

# The objective passed to OptimizationFunction is the *square* of the mean
# squared residual, so taking sqrt recovers it.
@test sqrt(loss_squared) == sum(abs2, dxdtguess .- fs) / length(ts)

# Equivalently, in terms of NeuralPDE.inner_loss:
@test loss_squared == sum(abs2, NeuralPDE.inner_loss(phi, rhs, autodiff, ts, init_params, p))
@test sqrt(loss_squared) == NeuralPDE.inner_loss(phi, rhs, autodiff, ts, init_params, p)
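For comparison, here is a hypothetical sketch of what I would have expected to be passed to OptimizationFunction. This is not the library's code; loss_now and loss_expected are made-up names, reusing phi, rhs, autodiff, ts, and p from above:

# Hypothetical sketch, not NeuralPDE's actual code.
# What the generated objective effectively computes today:
loss_now(θ, _p) = sum(abs2, NeuralPDE.inner_loss(phi, rhs, autodiff, ts, θ, p))   # (∑ᵢ ℓᵢ²)²
# What I would expect instead: pass the inner loss through unchanged.
loss_expected(θ, _p) = NeuralPDE.inner_loss(phi, rhs, autodiff, ts, θ, p)         # ∑ᵢ ℓᵢ²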

milzj commented on Sep 20, 2022

Yeah that's not desired, but it shouldn't actually change the loss function. It would be good to figure out why it's squared instead of square rooted though: something got inverted.

ChrisRackauckas commented on Sep 24, 2022
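A side note on why the extra squaring leaves the optimum unchanged (assuming only that the loss is nonnegative, which a sum of squares is): since $x \mapsto x^2$ is strictly increasing on $[0, \infty)$,

$$\arg\min_\theta \Big( \sum_i \ell_i(\theta)^2 \Big)^2 = \arg\min_\theta \sum_i \ell_i(\theta)^2.$$

The gradient is rescaled, though: $\nabla_\theta (L^2) = 2L \, \nabla_\theta L$ for $L = \sum_i \ell_i^2$, so the squaring can still affect the trajectory of a gradient-based optimizer such as Adam.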