NeuralPDE.jl
Error in solution representation when using `Lux.BatchNorm`
Evaluating the solution of an optimization procedure using

```julia
discretization = PhysicsInformedNN(chain, QuadratureTraining())
prob = discretize(system, discretization)
result = Optimization.solve(prob, BFGS())
discretization.phi([0.5], result.u) # Error here
```

yields the error

```
BoundsError: attempt to access Tuple{Int64} at index [0]
```

if the Lux `Chain` contains a `Lux.BatchNorm` layer.
Here is an MWE to reproduce the issue:

```julia
using Lux: Chain, Dense, BatchNorm, relu
using NeuralPDE
using Optimization, OptimizationOptimJL
using ModelingToolkit: Interval, infimum, supremum

n = 18
chain = Chain(
    Dense(1, n),
    BatchNorm(n, relu), # Without this line there is no error
    Dense(n, 1))

@parameters t
@variables f(..)
D = Differential(t)
diffeq = [ f(t) ~ D(f(t)) ]
bcs = [ f(0) ~ 1 ]
tdomain = t ∈ Interval(0, 1)
@named system = PDESystem(diffeq, bcs, [tdomain], [t], [f(t)])

discretization = PhysicsInformedNN(chain, QuadratureTraining())
prob = discretize(system, discretization)
result = Optimization.solve(prob, BFGS())
discretization.phi([0.5], result.u) # Error occurs here
```
Output of `Pkg.status()`:

```
⌅ [b2108857] Lux v0.4.58
  [961ee093] ModelingToolkit v8.63.0
  [315f7962] NeuralPDE v5.7.0
  [7f7a1694] Optimization v3.15.2
  [36348300] OptimizationOptimJL v0.1.9
```
A full stacktrace can be found here.
After experimenting a bit more, I think the problem might be that `Lux.apply` expects a `Matrix` rather than a `Vector` when a `BatchNorm` layer is in the chain. This second MWE makes it clear:
```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

xvec = rand(rng, Float32, 1)       # 1-element Vector{Float32}
xmatrix = rand(rng, Float32, 1, 1) # 1×1 Matrix{Float32}

model = Chain(
    Dense(1, 20, tanh),
    Chain(Dense(20, 1, tanh),
          Dense(1, 10)))
ps, st = Lux.setup(rng, model)
Lux.apply(model, xmatrix, ps, st) # This works
Lux.apply(model, xvec, ps, st)    # This works

norm_model = Chain(
    Dense(1, 20, tanh),
    BatchNorm(20, relu),
    Chain(Dense(20, 1, tanh),
          Dense(1, 10)))
norm_ps, norm_st = Lux.setup(rng, norm_model)
Lux.apply(norm_model, xmatrix, norm_ps, norm_st) # This works
Lux.apply(norm_model, xvec, norm_ps, norm_st)    # This gives the error
```
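For context, Lux lays data out as features × batch. `BatchNorm` normalises each feature using statistics taken over the batch (last) dimension, so it needs that dimension to exist, and a plain `Vector` has none. One plausible reading of the `BoundsError` (index 0 of a one-element size tuple) is that the layer looks up the feature dimension at position `ndims(x) - 1`, which is 0 for a `Vector`. That is an assumption and has not been checked against the Lux source. The sketch below was written against a recent Lux release and not tested on Lux v0.4.58. It shows the convention on a standalone layer, and also why a one-sample batch such as `xmatrix` above is degenerate in training mode:

```julia
using Lux, Random, Statistics

rng = Random.default_rng()
Random.seed!(rng, 0)

# BatchNorm over 3 features; the default affine parameters are scale = 1, bias = 0
bn = BatchNorm(3)
ps, st = Lux.setup(rng, bn)   # the default state is in training mode

# Lux convention: features along the first dimension, samples along the last
x = rand(rng, Float32, 3, 8)  # 3 features × 8 samples

# In training mode each feature (row) is normalised with the batch's own statistics
y, _ = Lux.apply(bn, x, ps, st)
feature_means = vec(mean(y; dims=2))  # per-feature mean across the batch, ≈ 0

# With a batch of one sample, each feature equals its own batch mean,
# so the normalised output collapses to (approximately) the bias, here 0
x1 = x[:, 1:1]                # 3×1 matrix: one sample, batch dimension kept
y1, _ = Lux.apply(bn, x1, ps, st)
```

In the MWE's `norm_model`, this means the `1×1` call that "works" in training mode sends `relu(bias)` onward regardless of the input value.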
Should the issue be transferred to Lux, or should NeuralPDE internally transform inputs to a `Matrix`?
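If the conversion were done before calling `Lux.apply`, in NeuralPDE or in user code, one option is to give vector inputs a trailing batch dimension of size 1. Below is a minimal sketch that reuses the second MWE's model. The helpers `as_batch` and `predict` are hypothetical names and are not part of Lux or NeuralPDE. `Lux.testmode` is used so that `BatchNorm` normalises with its running statistics instead of the statistics of a one-sample batch:

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

norm_model = Chain(
    Dense(1, 20, tanh),
    BatchNorm(20, relu),
    Chain(Dense(20, 1, tanh),
          Dense(1, 10)))
norm_ps, norm_st = Lux.setup(rng, norm_model)

# Hypothetical helper: treat a single input vector as a batch of one sample
as_batch(x::AbstractVector) = reshape(x, :, 1)  # n-element Vector -> n×1 Matrix
as_batch(x::AbstractMatrix) = x                 # already features × batch

# Hypothetical helper: evaluate the model at a point (or a batch of points)
function predict(model, x, ps, st)
    # testmode switches BatchNorm to its running mean/variance
    y, _ = Lux.apply(model, as_batch(x), ps, Lux.testmode(st))
    return y
end

xvec = rand(rng, Float32, 1)  # 1-element Vector{Float32}
y_vec = predict(norm_model, xvec, norm_ps, norm_st)
y_mat = predict(norm_model, reshape(xvec, 1, 1), norm_ps, norm_st)
```

For the original NeuralPDE MWE, the analogous (untested) workaround would be to pass the point as a 1×1 matrix, e.g. `discretization.phi(reshape([0.5], 1, 1), result.u)`. Whether `phi` forwards a matrix unchanged in NeuralPDE v5.7.0 is an assumption.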