NeuralPDE.jl
Rewrite NeuralPDE
If anyone wants to rewrite NeuralPDE, I want to share some of my expectations here, some of which have been reflected in Sophon.jl.
The following code is mostly pseudocode, used only to illustrate the concepts; I will add more as I think of it.
On PhysicsInformedNN
A PINN should be created and accessed in the form of named tuples, just like Lux.Chain.
julia> chain = Chain(u = Dense(2, 16), v = Dense(2, 16))
Chain(
u = Dense(2 => 16), # 48 parameters
v = Dense(2 => 16), # 48 parameters
) # Total: 96 parameters,
# plus 0 states, summarysize 32 bytes.
Then wrap it in a PINN structure. The sub-networks can be accessed through getindex, but getproperty is preferred.
pinn = PINN(chain)
pinn.u # something like chain.layers.u, but the state also needs to be handled
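For illustration, here is a minimal sketch of such a wrapper, assuming Lux.jl; the PINN struct and its state handling are hypothetical, not an existing API:

using Lux, Random

# Hypothetical PINN wrapper: stores the container model together with its
# state, and `getproperty` returns a callable for each named sub-network
# so that `pinn.u(x, θ.u)` evaluates the `u` branch.
struct PINN{M, S}
    model::M
    state::S
end

PINN(model::Chain; rng = Random.default_rng()) =
    PINN(model, Lux.initialstates(rng, model))

function Base.getproperty(pinn::PINN, name::Symbol)
    name in (:model, :state) && return getfield(pinn, name)
    layer = getfield(pinn, :model).layers[name]
    st = getfield(pinn, :state)[name]
    return (x, θ) -> first(layer(x, θ, st))  # drop the returned state
end

# Usage:
chain = Chain(u = Dense(2, 16), v = Dense(2, 16))
pinn = PINN(chain)
θ, _ = Lux.setup(Random.default_rng(), chain)
pinn.u(rand(Float32, 2, 10), θ.u)  # 16×10 output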
Now we no longer need to record the order of dependent variables in the PDESystem. Parsing thus becomes simpler.
# u(t,x) + v(t,x) lowers to
begin
    phi_u = pinn.u
    phi_v = pinn.v
end
begin
    θ_u = θ.u
    θ_v = θ.v
end
phi_u(coord_u, θ_u) .+ phi_v(coord_v, θ_v)
The code generating the begin ... end blocks should be wrapped in a reusable function, as should the transformation u(t,x) -> phi_u(coord_u, θ_u). Perhaps @rule can play a role here; I'm not sure. We need to break transform_expressions down into small functions or rewrite rules.
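As a sketch of how small such a function can be (operating on plain Exprs for brevity rather than symbolic terms; the helper name is mine):

# Lowers a dependent-variable call such as :(u(t, x)) into
# :(phi_u(coord_u, θ_u)). A real parser would walk Symbolics expressions,
# but the transformation itself is this small.
function lower_depvar_call(ex::Expr)
    @assert ex.head === :call
    name = ex.args[1]                        # e.g. :u
    return Expr(:call, Symbol(:phi_, name),  # :phi_u
                Symbol(:coord_, name),       # :coord_u
                Symbol(:θ_, name))           # :θ_u
end

lower_depvar_call(:(u(t, x)))  # returns :(phi_u(coord_u, θ_u))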
Single output & multioutput
I think we can just change
phi_u = phi.u, θ_u = θ.u
to
phi_u = phi, θ_u = θ
for the single-output case; there is no need to differentiate in the parsing beyond that.
On coordinates
The current parsing looks like this:
# u(t,x), v(t)
(coord, θ) ->
    let (t, x) = (coord[[1], :], coord[[2], :])
        coord1 = vcat(t, x)
        coord2 = vcat(t)
        # computation...
    end
Here vcat causes unnecessary memory allocations. There should be a get_coord function that performs as few vcats as possible.
function get_coord(u)
    if arguments(u) == all_indvars
        return :(coord)
    elseif length(arguments(u)) == 1
        return arguments(u)[1] # say :t
    else
        return :(vcat($(arguments(u)...)))
    end
end
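A runnable toy version over plain Exprs (with all_indvars passed explicitly; the names are mine) shows the intent:

# Toy version of get_coord. `args` are the arguments of a dependent-variable
# call, `all_indvars` the ordered independent variables of the PDESystem.
function get_coord(args::Vector{Symbol}, all_indvars::Vector{Symbol})
    if args == all_indvars
        return :coord                 # full batch matrix, no allocation
    elseif length(args) == 1
        return args[1]                # e.g. :t, still no vcat
    else
        return :(vcat($(args...)))    # vcat only when unavoidable
    end
end

get_coord([:t, :x], [:t, :x])  # :coord
get_coord([:t], [:t, :x])      # :t
get_coord([:x, :t], [:t, :x])  # :(vcat(x, t))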
Similarly, we prefer to use named suffixes.
# u(t,x), v(t)
(coord, θ) ->
    let (t, x) = (coord[[1], :], coord[[2], :])
        begin
            coord_u = coord
            coord_v = t
        end
        # ...
    end
Periodic boundary conditions
Periodic boundary conditions are not currently handled correctly (see #469), because the same dependent variable appears with different inputs, so it needs special treatment.
A more general approach is to parse expressions like u(1.0,t) into
(coord, θ) ->
    begin
        phi_u = pinn.u
    end
    begin
        θ_u = θ.u
    end
    let (x, t) = (coord[[1], :], coord[[2], :])
        begin
            coord_u = coord
        end
        phi_u(vcat(Base.Fix1(fill, 1.0)(size(x)), t), θ_u)
    end
In this way, periodic conditions are naturally parsed correctly.
# u(-1.0, t) ~ u(1.0, t) lowers to
begin
    phi_u = pinn.u
end
begin
    θ_u = θ.u
end
let (x, t) = (coord[[1], :], coord[[2], :])
    begin
        coord_u = coord
    end
    phi_u(vcat(Base.Fix1(fill, -1.0)(size(x)), t), θ_u) .-
        phi_u(vcat(Base.Fix1(fill, 1.0)(size(x)), t), θ_u)
end
Even if something like this appears in an equation:
u(x,t) + u(1.0,t) + v(x,1.0)
it will still be parsed correctly.
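Following the same pattern as above, my sketch of the lowered form would be:

let (x, t) = (coord[[1], :], coord[[2], :])
    coord_u = coord
    phi_u(coord_u, θ_u) .+
        phi_u(vcat(Base.Fix1(fill, 1.0)(size(x)), t), θ_u) .+
        phi_v(vcat(x, Base.Fix1(fill, 1.0)(size(t))), θ_v)
end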
On derivatives
I imagine having the following parsing function:
expr = Dxxx(Dtt(u(t,x)))

function get_directions(expr)
    # some code
    return [2, 2, 2, 1, 1] # from outermost to innermost
end

function get_derivative(expr)
    directions = get_directions(expr) # [2,2,2,1,1]
    mixed = any(!==(first(directions)), directions) # is this a mixed derivative?
    if mixed
        orders = get_orders(directions)  # e.g. [2,2,2,1,1] -> [3,2]
        directions = unique(directions)  # e.g. [2, 1]
        εs = map(get_ε, directions, orders)
        # generate a nested expression for the mixed derivative
        return quote
            finitediff((x, ps) -> finitediff(phi_u, x, ps, $(εs[2]), $(Val(orders[2]))),
                       coord_u, θ_u, $(εs[1]), $(Val(orders[1])))
        end
    else
        order = length(directions)
        ε = get_ε(first(directions), order)
        return :(finitediff(phi_u, coord_u, θ_u, $ε, $(Val(order))))
    end
end
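For completeness, here is one possible get_orders helper (my sketch; only the name comes from the pseudocode above):

# Collapses the outermost-to-innermost direction list into per-direction
# orders, preserving the order in which directions first appear.
get_orders(directions) = [count(==(d), directions) for d in unique(directions)]

get_orders([2, 2, 2, 1, 1])  # == [3, 2]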
Sample
There should be an independent sample function that can be used for resampling, with the result then passed into the problem:
data = sample(pde_system, sampler)
prob = remake(prob; p = data)
Here length(pde_datasets) == length(eqs) and length(boundary_datasets) == length(bcs).
Note that I believe the same data points should be used for all equations, which may help convergence. In any case, it saves memory.
function sample(pde_system, sampler)
    pde_dataset = sample(bounds, sampler) # one shared dataset for every equation
    pde_datasets = [pde_dataset for _ in 1:length(pde_system.eqs)]
    # boundary_datasets are sampled analogously, one per boundary condition
    return [pde_datasets; boundary_datasets]
end
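As a concrete sketch of the shared-dataset idea, assuming QuasiMonteCarlo.jl as the sampler backend (the bounds and counts here are made up):

using QuasiMonteCarlo

lb, ub = [0.0, -1.0], [1.0, 1.0]  # assumed bounds for (t, x)
pde_dataset = QuasiMonteCarlo.sample(500, lb, ub, SobolSample())  # 2×500 matrix

# The same dataset is reused by every equation, costing no extra memory:
pde_datasets = [pde_dataset for _ in 1:3]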
Scalarizing
Before scalarizing, each equation or boundary condition returns the residuals at each data point, not a scalar; we scalarize only at the very end.
function scalarize(phi, weights::Tuple{Vararg{Function, N}},
                   loss_functions::Tuple{Vararg{Function, N}}) where N
    ex = :(mean($(weights[1])($phi, data[1], θ) .*
                abs2.($(loss_functions[1])(data[1], θ))))
    for i in 2:N
        ex = :(mean($(weights[i])($phi, data[$i], θ) .*
                    abs2.($(loss_functions[i])(data[$i], θ))) + $ex)
    end
    loss_f = :((θ, data) -> $ex)
    return eval(loss_f)
end
Note that weights is a tuple of functions, each assigning a weight to every data point of its equation; the most basic case is Returns(1). Many adaptive algorithms need point-wise residuals, and we can hook into weights to implement adaptive weighting without any additional code changes.
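For instance, a residual-based adaptive scheme could be slotted in as just another weight function with the (phi, data, θ) signature used above (a sketch; residual_fn and the normalization are my assumptions):

using Statistics: mean

# Builds a weight function for one equation. `residual_fn` is assumed to
# return the point-wise residuals of that equation on its dataset.
function make_adaptive_weight(residual_fn)
    return (phi, data, θ) -> begin
        r = abs2.(residual_fn(data, θ))
        r ./ mean(r)  # up-weight high-residual points; the mean weight stays 1
    end
end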
@xtalax can you describe where your parser rewrite is at? I think it would be fine to even do an earlier v6 with some features lost if it gets closer to these goals. One thing I think needs a total rewrite anyways is the integro differential equation support.
I have changed everything in the call stack up to parse_equation to use Symbolics and have started on parse_equation; I can pivot to this style, though.
Sophon can generate the symbolic loss functions I want, but it still falls short of what I'm describing here. Its underlying implementation is just as hard to read as NeuralPDE's, especially after the introduction of DeepONet. I removed support for integration and default parameters there; those should be kept in NeuralPDE.
What I would expect is to use advanced tools to transform expressions, such as MacroTools.postwalk, which seems like something you're already using.
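For example, the whole depvar-call lowering can be written as a single postwalk pass (illustrative only; the depvars list and naming scheme are assumptions):

using MacroTools

# Rewrite every dependent-variable call u(args...) in an equation Expr
# into the lowered phi_u(coord_u, θ_u) form.
depvars = (:u, :v)
ex = :(u(t, x) + v(t, x))
lowered = MacroTools.postwalk(ex) do node
    if @capture(node, f_(args__)) && f in depvars
        return :($(Symbol(:phi_, f))($(Symbol(:coord_, f)), $(Symbol(:θ_, f))))
    end
    return node
end
# lowered == :(phi_u(coord_u, θ_u) + phi_v(coord_v, θ_v))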
I'm happy with the design of Sophon's interface:
prob = Sophon.discretize(pde_system, pinn, sampler, strategy)
By the way, I don't actually use RuntimeGeneratedFunctions explicitly at all, and everything seems to work fine.