Segfault during Enzyme.jl AD in OptimizationFunction
I was trying to build a physics-informed neural network (PINN) model; here is the complete program:
using Lux, Optimization, OptimizationOptimJL, Enzyme, ComponentArrays
using Plots, Random, ProgressMeter, Statistics
Enzyme.Compiler.VERBOSE_ERRORS[] = true
# Constants
const U_lid = 1.0f0
const ν = 0.01f0
const Re = 100.0f0
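# lid-driven cavity on the unit square: Re = U_lid * L / ν = 1.0 * 1.0 / 0.01 = 100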
# Data generation
function generate_data()
N_f = 10_000
x_f = rand(Float32, N_f)
y_f = rand(Float32, N_f)
N_b = 500
x_left = zeros(Float32, N_b)
y_left = rand(Float32, N_b)
x_right = ones(Float32, N_b)
y_right = rand(Float32, N_b)
x_bottom = rand(Float32, N_b)
y_bottom = zeros(Float32, N_b)
x_top = rand(Float32, N_b)
y_top = ones(Float32, N_b)
X = vcat(
hcat(x_f, y_f),
hcat(x_left, y_left),
hcat(x_right, y_right),
hcat(x_bottom, y_bottom),
hcat(x_top, y_top)
)
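# boundary targets: no-slip (u = v = 0) on the left, right, and bottom walls; u = U_lid, v = 0 on the moving top lid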
u_b = vcat(zeros(Float32, 3*N_b), ones(Float32, N_b))
v_b = zeros(Float32, 4*N_b)
return X, u_b, v_b
end
# Model
function build_pinn()
Lux.Chain(
Lux.Dense(2 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 20, tanh),
Lux.Dense(20 => 3)
)
end
# Physics loss
function compute_physics_loss(model, ps, st, x)
h = 1f-4
N = size(x, 2)
loss = 0.0f0
for i in 1:N
xy = x[:, i]
output = model(xy, ps, st)[1]
u, v, p = output
# First derivatives (manual finite diff for compatibility with Enzyme)
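# central difference: f'(x) ≈ (f(x + h) - f(x - h)) / (2h)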
u_x = (model(xy .+ Float32[h, 0f0], ps, st)[1][1] - model(xy .- Float32[h, 0f0], ps, st)[1][1]) / (2h)
u_y = (model(xy .+ Float32[0f0, h], ps, st)[1][1] - model(xy .- Float32[0f0, h], ps, st)[1][1]) / (2h)
v_x = (model(xy .+ Float32[h, 0f0], ps, st)[1][2] - model(xy .- Float32[h, 0f0], ps, st)[1][2]) / (2h)
v_y = (model(xy .+ Float32[0f0, h], ps, st)[1][2] - model(xy .- Float32[0f0, h], ps, st)[1][2]) / (2h)
p_x = (model(xy .+ Float32[h, 0f0], ps, st)[1][3] - model(xy .- Float32[h, 0f0], ps, st)[1][3]) / (2h)
p_y = (model(xy .+ Float32[0f0, h], ps, st)[1][3] - model(xy .- Float32[0f0, h], ps, st)[1][3]) / (2h)
# Second derivatives
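# second-order central difference: f''(x) ≈ (f(x + h) - 2f(x) + f(x - h)) / h^2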
u_xx = (model(xy .+ Float32[h, 0f0], ps, st)[1][1] - 2u + model(xy .- Float32[h, 0f0], ps, st)[1][1]) / h^2
u_yy = (model(xy .+ Float32[0f0, h], ps, st)[1][1] - 2u + model(xy .- Float32[0f0, h], ps, st)[1][1]) / h^2
v_xx = (model(xy .+ Float32[h, 0f0], ps, st)[1][2] - 2v + model(xy .- Float32[h, 0f0], ps, st)[1][2]) / h^2
v_yy = (model(xy .+ Float32[0f0, h], ps, st)[1][2] - 2v + model(xy .- Float32[0f0, h], ps, st)[1][2]) / h^2
# Physics
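# steady incompressible Navier-Stokes residuals: continuity ∇·u = 0, momentum (u·∇)u + ∇p - ν ∇²u = 0 (ρ = 1)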
continuity = u_x + v_y
momentum_x = u*u_x + v*u_y + p_x - ν*(u_xx + u_yy)
momentum_y = u*v_x + v*v_y + p_y - ν*(v_xx + v_yy)
loss += continuity^2 + momentum_x^2 + momentum_y^2
end
return loss / N
end
# Boundary loss
function compute_boundary_loss(model, ps, st, x_b, u_b, v_b)
loss = 0.0f0
N = size(x_b, 2)
for i in 1:N
xy = x_b[:, i]
out = model(xy, ps, st)[1]
u_pred, v_pred = out[1], out[2]
loss += (u_pred - u_b[i])^2 + (v_pred - v_b[i])^2
end
return loss / N
end
# Training
function train_pinn(;epochs=10_000, lr=1e-3)
X, u_b, v_b = generate_data()
X_f = X[1:10_000, :]'
X_b = X[10_001:end, :]'
u_b = u_b[1:end]   # u_b and v_b already contain only the 4*N_b boundary targets
v_b = v_b[1:end]
rng = Random.MersenneTwister(1234)
model = build_pinn()
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)
function loss_function(p, _)
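# draw a fresh mini-batch of interior (collocation) and boundary points on every loss evaluation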
idx_f = rand(1:size(X_f, 2), min(1024, size(X_f, 2)))
idx_b = rand(1:size(X_b, 2), min(256, size(X_b, 2)))
physics_loss = compute_physics_loss(model, p, st, X_f[:, idx_f])
boundary_loss = compute_boundary_loss(model, p, st, X_b[:, idx_b], u_b[idx_b], v_b[idx_b])
return physics_loss + boundary_loss
end
optf = OptimizationFunction(loss_function, Optimization.AutoEnzyme())
optprob = OptimizationProblem(optf, ps)
opt = OptimizationOptimJL.LBFGS()
history = Float32[loss_function(ps, nothing)]
prog = Progress(epochs, 1)
for epoch in 1:epochs
res = solve(optprob, opt, maxiters=1)
ps = res.u
optprob = OptimizationProblem(optf, ps)
current_loss = loss_function(ps, nothing)
history = vcat(history, current_loss)
if epoch % 100 == 0
ProgressMeter.update!(prog, epoch; showvalues=[
(:Epoch, epoch),
(:Loss, current_loss)
])
end
end
return model, ps, st, history
end
# Visualization
function plot_results(model, ps, st)
nx, ny = 50, 50
x = range(0f0, 1f0, length=nx)
y = range(0f0, 1f0, length=ny)
u = [model([x[i], y[j]], ps, st)[1][1] for i in 1:nx, j in 1:ny]
v = [model([x[i], y[j]], ps, st)[1][2] for i in 1:nx, j in 1:ny]
p = [model([x[i], y[j]], ps, st)[1][3] for i in 1:nx, j in 1:ny]
# Plots.jl has no built-in streamline recipe, so draw the velocity field with quiver arrows instead
xs = vec([x[i] for i in 1:nx, j in 1:ny])
ys = vec([y[j] for i in 1:nx, j in 1:ny])
plt_str = quiver(xs, ys, quiver=(0.05f0 .* vec(u), 0.05f0 .* vec(v)),
    title="Velocity Field", xlabel="x", ylabel="y", color=:black)
vel_mag = sqrt.(u.^2 .+ v.^2)
plt_mag = heatmap(x, y, vel_mag', title="Velocity Magnitude", xlabel="x", ylabel="y", c=:viridis)
plt_p = heatmap(x, y, p', title="Pressure Field", xlabel="x", ylabel="y", c=:viridis)
plot(plt_str, plt_mag, plt_p, layout=(1,3), size=(1200,400))
end
# Main
model, ps, st, history = train_pinn(epochs=10_000)
plot_results(model, ps, st)
plot(history, yscale=:log10, label="Training Loss", xlabel="Epoch", ylabel="Loss (log scale)")
The error message is:
Enzyme compilation failed due to an internal error.
Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope:
Here is my working environment:
IDE: Visual Studio Code, Jupyter Notebook mode
OS: Debian 12 (kernel version 6.1.0-34-amd64 (64-bit))
Julia version: 1.11.5
Enzyme version: v0.13.43
Working on CPU
Processor: 4 × Intel® Core™ i5-6300U CPU @ 2.40GHz
Memory: 30.6 GiB of RAM
Manufacturer: LENOVO
Let me know if you need any clarification, and thanks for all the amazing work on Enzyme!
Could you try with Julia 1.10?
We just released a fix for a similar issue on 1.11; @muis-code, can you retry?
gentle ping @muis-code
On 1.10 this seems to work without issue.
On 1.11 the illegal-replacement error is gone, but now it hits a segfault.
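To help narrow down whether the segfault comes from Enzyme itself or from the Optimization.jl wrapper, one possible reduced check is to differentiate the same loss closure directly with Enzyme. This is only a sketch: it assumes loss_function and ps are reachable at top level (e.g. hoisted out of train_pinn, or run inside it), and it uses Enzyme.autodiff with a Duplicated shadow instead of going through OptimizationFunction.
using Enzyme
# zeroed shadow with the same ComponentArray structure as ps
dps = Enzyme.make_zero(ps)
# reverse-mode AD of the scalar loss; the unannotated closure is treated as constant data
Enzyme.autodiff(Enzyme.Reverse,
    p -> loss_function(p, nothing),
    Enzyme.Active,
    Enzyme.Duplicated(ps, dps))
@show sum(abs, dps)   # nonzero if the reverse pass completed
If this already segfaults, the issue is independent of Optimization.jl; if it works, the problem is more likely in how the gradient is wired up through OptimizationFunction.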