Enzyme.jl icon indicating copy to clipboard operation
Enzyme.jl copied to clipboard

Segfault during Enzyme.jl AD in OptimizationFunction

Open muis-code opened this issue 7 months ago • 5 comments

I was trying to build a physics-informed neural network model; here is the complete program code:


using Lux, Optimization, OptimizationOptimJL, Enzyme, ComponentArrays
using Plots, Random, ProgressMeter, Statistics

# Enable detailed Enzyme compiler error output (needed for bug reports).
Enzyme.Compiler.VERBOSE_ERRORS[] = true

# Constants
const U_lid = 1.0f0  # lid (top-wall) velocity — NOTE(review): unused below; the top-wall target is hard-coded as ones(...) in generate_data
const ν = 0.01f0     # kinematic viscosity, used in the momentum residuals
const Re = 100.0f0   # Reynolds number — NOTE(review): unused below

# Data generation
"""
    generate_data(; N_f=10_000, N_b=500)

Generate training points for the lid-driven-cavity PINN on the unit square.

Returns `(X, u_b, v_b)` where
- `X` is an `(N_f + 4N_b) × 2` matrix: rows `1:N_f` are interior collocation
  points, followed by `N_b` points on each of the left, right, bottom, and
  top walls (in that order);
- `u_b`, `v_b` are length-`4N_b` boundary velocity targets for the wall rows
  of `X`: `u = 1` on the moving lid (top wall), `0` on the other walls,
  and `v = 0` everywhere.

The point counts were previously hard-coded; they are now keyword arguments
with the original values as defaults (backward compatible).
"""
function generate_data(; N_f::Int=10_000, N_b::Int=500)
    # Interior collocation points, uniform in [0, 1]².
    x_f = rand(Float32, N_f)
    y_f = rand(Float32, N_f)

    # Boundary points on the four walls.
    x_left = zeros(Float32, N_b)
    y_left = rand(Float32, N_b)
    x_right = ones(Float32, N_b)
    y_right = rand(Float32, N_b)
    x_bottom = rand(Float32, N_b)
    y_bottom = zeros(Float32, N_b)
    x_top = rand(Float32, N_b)
    y_top = ones(Float32, N_b)

    X = vcat(
        hcat(x_f, y_f),
        hcat(x_left, y_left),
        hcat(x_right, y_right),
        hcat(x_bottom, y_bottom),
        hcat(x_top, y_top)
    )

    # Targets: the first 3N_b wall points (left, right, bottom) are no-slip;
    # the last N_b (top) move with the lid.
    u_b = vcat(zeros(Float32, 3 * N_b), ones(Float32, N_b))
    v_b = zeros(Float32, 4 * N_b)

    return X, u_b, v_b
end

# Model
"""
    build_pinn()

Construct the PINN: a fully connected network mapping `(x, y)` to
`(u, v, p)` — a 2-input MLP with eight 20-unit `tanh` hidden layers and a
linear 3-output head.
"""
function build_pinn()
    width = 20
    n_hidden = 8
    # First hidden layer adapts the 2-D input; the remaining hidden layers
    # are identical width-preserving tanh layers.
    body = [Lux.Dense(width => width, tanh) for _ in 2:n_hidden]
    return Lux.Chain(
        Lux.Dense(2 => width, tanh),
        body...,
        Lux.Dense(width => 3)
    )
end

# Physics loss
"""
    compute_physics_loss(model, ps, st, x)

Mean squared residual of the steady incompressible Navier–Stokes equations
at the collocation points `x` (a `2 × N` matrix of `(x, y)` columns).
Derivatives are taken by central finite differences (step `h`) rather than
nested AD, for compatibility with Enzyme.

The original evaluated the model 12 times per point; each of the four
stencil neighbours is now evaluated once and reused for all first and
second derivatives (4 neighbour calls + 1 centre call per point), with
identical arithmetic and results.
"""
function compute_physics_loss(model, ps, st, x)
    h = 1f-4
    N = size(x, 2)
    loss = 0.0f0

    # Loop-invariant finite-difference offsets, hoisted out of the loop.
    dx = Float32[h, 0f0]
    dy = Float32[0f0, h]

    for i in 1:N
        xy = x[:, i]
        output = model(xy, ps, st)[1]
        u, v = output[1], output[2]

        # One forward pass per stencil neighbour; components 1/2/3 of each
        # result are (u, v, p) at that point.
        out_xp = model(xy .+ dx, ps, st)[1]
        out_xm = model(xy .- dx, ps, st)[1]
        out_yp = model(xy .+ dy, ps, st)[1]
        out_ym = model(xy .- dy, ps, st)[1]

        # First derivatives (second-order central differences).
        u_x = (out_xp[1] - out_xm[1]) / (2h)
        u_y = (out_yp[1] - out_ym[1]) / (2h)
        v_x = (out_xp[2] - out_xm[2]) / (2h)
        v_y = (out_yp[2] - out_ym[2]) / (2h)
        p_x = (out_xp[3] - out_xm[3]) / (2h)
        p_y = (out_yp[3] - out_ym[3]) / (2h)

        # Second derivatives (three-point central stencil).
        u_xx = (out_xp[1] - 2u + out_xm[1]) / h^2
        u_yy = (out_yp[1] - 2u + out_ym[1]) / h^2
        v_xx = (out_xp[2] - 2v + out_xm[2]) / h^2
        v_yy = (out_yp[2] - 2v + out_ym[2]) / h^2

        # PDE residuals: mass conservation and x/y momentum.
        continuity = u_x + v_y
        momentum_x = u*u_x + v*u_y + p_x - ν*(u_xx + u_yy)
        momentum_y = u*v_x + v*v_y + p_y - ν*(v_xx + v_yy)

        loss += continuity^2 + momentum_x^2 + momentum_y^2
    end
    return loss / N
end

# Boundary loss
"""
    compute_boundary_loss(model, ps, st, x_b, u_b, v_b)

Mean squared error between the model's predicted `(u, v)` and the target
boundary velocities `(u_b, v_b)` at the boundary points `x_b`
(a `2 × N` matrix of `(x, y)` columns).
"""
function compute_boundary_loss(model, ps, st, x_b, u_b, v_b)
    total = 0.0f0
    n = size(x_b, 2)
    for j in 1:n
        pt = x_b[:, j]
        pred = model(pt, ps, st)[1]
        du = pred[1] - u_b[j]
        dv = pred[2] - v_b[j]
        total += du^2 + dv^2
    end
    return total / n
end

# Training
"""
    train_pinn(; epochs=10_000, lr=1e-3)

Train the lid-driven-cavity PINN with LBFGS via Optimization.jl, using
Enzyme for gradients. Returns `(model, ps, st, history)` where `history`
records the loss before training and after each epoch.

NOTE(review): `lr` is accepted but never passed to the optimizer (LBFGS
takes no learning rate); kept for interface compatibility.
"""
function train_pinn(;epochs=10_000, lr=1e-3)
    X, u_b, v_b = generate_data()
    # Rows 1:10_000 are collocation points, the rest boundary points
    # (matches generate_data's default N_f/N_b); transpose to 2 × N columns.
    X_f = X[1:10_000, :]'
    X_b = X[10_001:end, :]'
    # u_b/v_b already cover exactly the boundary rows; the original sliced
    # them with `10_001-10_000:end` (i.e. `1:end`), a no-op copy dropped here.

    rng = Random.MersenneTwister(1234)
    model = build_pinn()
    ps, st = Lux.setup(rng, model)
    ps = ComponentArray(ps)

    function loss_function(p, _)
        # Mini-batch sampling; drawn from the seeded rng so runs are
        # reproducible (the original used the global RNG).
        idx_f = rand(rng, 1:size(X_f, 2), min(1024, size(X_f, 2)))
        idx_b = rand(rng, 1:size(X_b, 2), min(256, size(X_b, 2)))

        physics_loss = compute_physics_loss(model, p, st, X_f[:, idx_f])
        boundary_loss = compute_boundary_loss(model, p, st, X_b[:, idx_b], u_b[idx_b], v_b[idx_b])
        return physics_loss + boundary_loss
    end

    optf = OptimizationFunction(loss_function, Optimization.AutoEnzyme())
    optprob = OptimizationProblem(optf, ps)
    opt = OptimizationOptimJL.LBFGS()

    history = Float32[loss_function(ps, nothing)]
    prog = Progress(epochs, 1)

    for epoch in 1:epochs
        # One optimizer step per epoch; the problem is rebuilt around the
        # updated parameters so the next solve resumes from them.
        res = solve(optprob, opt, maxiters=1)
        ps = res.u
        optprob = OptimizationProblem(optf, ps)

        current_loss = loss_function(ps, nothing)
        # push! grows in amortized O(1); the original `vcat` reallocated the
        # whole history every epoch (O(epochs²) overall).
        push!(history, current_loss)

        if epoch % 100 == 0
            ProgressMeter.update!(prog, epoch; showvalues=[
                (:Epoch, epoch),
                (:Loss, current_loss)
            ])
        end
    end

    return model, ps, st, history
end

# Visualization
"""
    plot_results(model, ps, st)

Evaluate the trained PINN on a 50×50 grid over the unit square and plot the
velocity streamlines, velocity magnitude, and pressure field side by side.
"""
function plot_results(model, ps, st)
    nx, ny = 50, 50
    x = range(0f0, 1f0, length=nx)
    y = range(0f0, 1f0, length=ny)

    # Grid evaluations of the three network outputs (u, v, p); indexed
    # [i, j] = (x[i], y[j]), hence the transposes below for plotting.
    u = [model([x[i], y[j]], ps, st)[1][1] for i in 1:nx, j in 1:ny]
    v = [model([x[i], y[j]], ps, st)[1][2] for i in 1:nx, j in 1:ny]
    p = [model([x[i], y[j]], ps, st)[1][3] for i in 1:nx, j in 1:ny]

    # NOTE(review): `streamplot!` is not part of Plots.jl's documented API
    # (it is a Makie function) — confirm the active backend provides it.
    plt_str = plot(title="Velocity Field Streamlines", xlabel="x", ylabel="y")
    streamplot!(plt_str, x, y, u', v', color=:black, linewidth=1, arrowsize=1, density=2)

    vel_mag = sqrt.(u.^2 .+ v.^2)
    plt_mag = heatmap(x, y, vel_mag', title="Velocity Magnitude", xlabel="x", ylabel="y", c=:viridis)
    plt_p = heatmap(x, y, p', title="Pressure Field", xlabel="x", ylabel="y", c=:viridis)

    # Last expression is returned: the combined 1×3 figure.
    plot(plt_str, plt_mag, plt_p, layout=(1,3), size=(1200,400))
end

# Main script: train, visualize the solution fields, then plot the loss curve.
model, ps, st, history = train_pinn(epochs=10_000)
plot_results(model, ps, st)
plot(history, yscale=:log10, label="Training Loss", xlabel="Epoch", ylabel="Loss (log scale)")

The error message is:

Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope: 

Here is my working environment specification:

IDE: Visual Studio Code, Jupyter Notebook mode
OS: Debian 12 (kernel version 6.1.0-34-amd64 (64-bit))
Julia version: 1.11.5
Enzyme version: v0.13.43
Working on CPU
Processor: 4 × Intel® Core™ i5-6300U CPU @ 2.40GHz
Memory: 30.6 GiB of RAM
Manufacturer: LENOVO

muis-code avatar May 16 '25 07:05 muis-code

Let me know if you need any clarification, and thanks for all the amazing work on Enzyme!

muis-code avatar May 16 '25 08:05 muis-code

Could you try with Julia 1.10?

vchuravy avatar May 16 '25 08:05 vchuravy

we just released a fix for a similar issue on 1.11, @muis-code can you retry?

wsmoses avatar Jul 03 '25 04:07 wsmoses

gentle ping @muis-code

wsmoses avatar Sep 18 '25 13:09 wsmoses

on 1.10 this seems to work successfully without issue.

on 1.11 the illegal replacement is gone, but now it hits a segfault

wsmoses avatar Nov 09 '25 05:11 wsmoses