
Error in solution representation when using `Lux.BatchNorm`

NoFishLikeIan opened this issue · 1 comment

Evaluating the solution of an optimization procedure using

discretization = PhysicsInformedNN(chain, QuadratureTraining())
prob = discretize(system, discretization)

result = Optimization.solve(prob, BFGS())

discretization.phi([0.5], result.u) # Error here

yields the error

BoundsError: attempt to access Tuple{Int64} at index [0]

if the `Lux.Chain` contains a `Lux.BatchNorm` layer.

Here is an MWE to reproduce the issue:

using Lux: Chain, Dense, BatchNorm, relu
using NeuralPDE
using Optimization, OptimizationOptimJL
using ModelingToolkit: Interval, infimum, supremum

n = 18
chain = Chain(
    Dense(1, n),
    BatchNorm(n, relu), # Without this line there is no error
    Dense(n, 1)
)

@parameters t
@variables f(..)

D = Differential(t)

diffeq = [ f(t) ~ D(f(t)) ]
bcs = [ f(0) ~ 1 ]

tdomain = t ∈ Interval(0, 1)

@named system = PDESystem(diffeq, bcs, [tdomain], [t], [f(t)])

discretization = PhysicsInformedNN(chain, QuadratureTraining())
prob = discretize(system, discretization)

result = Optimization.solve(prob, BFGS())

discretization.phi([0.5], result.u) # Error occurs here
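
A possible workaround (untested here; assuming the error stems from the missing batch dimension, as explored below) is to promote the input to a 1×1 matrix before calling `phi`, continuing from the MWE above:

```julia
# Hypothetical workaround, continuing from the MWE above:
# pass a 1×1 Matrix (features × batch) instead of a 1-element Vector,
# since Lux.BatchNorm expects an explicit batch dimension.
x = reshape([0.5], 1, 1)
discretization.phi(x, result.u)
```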

`Pkg.status()` yields

⌅ [b2108857] Lux v0.4.58
  [961ee093] ModelingToolkit v8.63.0
  [315f7962] NeuralPDE v5.7.0
  [7f7a1694] Optimization v3.15.2
  [36348300] OptimizationOptimJL v0.1.9

A full stacktrace can be found here.

NoFishLikeIan commented Jul 28 '23 15:07

After experimenting a bit more, this looks like an issue with how `Lux.apply` expects a `Matrix` rather than a `Vector` when a `BatchNorm` layer is in the chain. This second MWE makes it clear:

using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

xvec = rand(rng, Float32, 1) # 1-element Vector{Float32}
xmatrix = rand(rng, Float32, 1, 1) # 1×1 Matrix{Float32}
model = Chain(
    Dense(1, 20, tanh), 
    Chain(Dense(20, 1, tanh), 
    Dense(1, 10))
)
ps, st = Lux.setup(rng, model)

Lux.apply(model, xmatrix, ps, st) # This works
Lux.apply(model, xvec, ps, st) # This works

norm_model = Chain(
    Dense(1, 20, tanh), 
    BatchNorm(20, relu),
    Chain(Dense(20, 1, tanh), 
    Dense(1, 10))
)

norm_ps, norm_st = Lux.setup(rng, norm_model)

Lux.apply(norm_model, xmatrix, norm_ps, norm_st) # This works
Lux.apply(norm_model, xvec, norm_ps, norm_st) # This gives the error
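
For reference, a minimal sketch (assuming Lux v0.4.x semantics, where the last dimension is treated as the batch) showing that reshaping the vector into a features × batch matrix makes the same call go through:

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

norm_model = Chain(Dense(1, 20, tanh), BatchNorm(20, relu), Dense(20, 10))
ps, st = Lux.setup(rng, norm_model)

xvec = rand(rng, Float32, 1)    # 1-element Vector{Float32}: fails in Lux.apply
xbatch = reshape(xvec, :, 1)    # 1×1 Matrix{Float32}: features × batch

y, st_ = Lux.apply(norm_model, xbatch, ps, st)  # succeeds, no BoundsError
```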

Should this issue be transferred to Lux, or should NeuralPDE internally convert inputs to a `Matrix`?

NoFishLikeIan commented Jul 28 '23 16:07