Enzyme.jl

Flux+Enzyme internal compiler error

Open freddycct opened this issue 7 months ago • 8 comments

using Flux
using Enzyme

clf = Dense(3 => 2)
function forward(clf::Dense, x::Vector{Float32})::Float32
	y = clf(x)
	log_y = Flux.logsoftmax(y)
	log_score = first(log_y)
	return log_score
end

x = rand(Float32, 3)
@show forward(clf, x)
grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x))
Enzyme compilation failed due to an internal error.

 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Stacktrace:
  [1] copy
    @ ./array.jl:350
  [2] unaliascopy
    @ ./abstractarray.jl:1516
  [3] unalias
    @ ./abstractarray.jl:1500
  [4] broadcast_unalias
    @ ./broadcast.jl:946
  [5] preprocess
    @ ./broadcast.jl:953
  [6] preprocess_args
    @ ./broadcast.jl:956
  [7] preprocess_args
    @ ./broadcast.jl:955
  [8] preprocess
    @ ./broadcast.jl:952
  [9] preprocess_args
    @ ./broadcast.jl:956
 [10] preprocess_args (repeats 2 times)
    @ ./broadcast.jl:955
 [11] preprocess
    @ ./broadcast.jl:952
 [12] override_bc_copyto!
    @ ~/.julia/packages/Enzyme/g1jMR/src/compiler/interpreter.jl:798
 [13] copyto!
    @ ./broadcast.jl:925
 [14] materialize!
    @ ./broadcast.jl:883
 [15] materialize!
    @ ./broadcast.jl:880
 [16] #logsoftmax!#202
    @ ~/.julia/packages/NNlib/CGMj3/src/softmax.jl:117

freddycct, Apr 16 '25 18:04

Curiously the error goes away for a matrix instead of vector input:

julia> x = rand(Float32, 3, 1);  # now a Matrix

julia> function forward(clf::Dense, x)::Float32  # signature widened to allow this
               y = clf(x)
               log_y = Flux.logsoftmax(y)
               log_score = first(log_y)
               return log_score
       end
forward (generic function with 2 methods)

julia> forward(clf, x)
-1.4793367f0

julia> grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x))
((weight = Float32[0.37867802 0.37851515 0.19702016; -0.37867802 -0.37851515 -0.19702016], bias = Float32[0.7722113, -0.7722113], σ = nothing), nothing)

julia> x = rand(Float32, 3);  # back to vector case

julia> forward(clf, x)
-1.7434618f0

julia> grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x))
ERROR: Enzyme compilation failed due to an internal error.
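Since a Matrix input differentiates here while a Vector input fails, one possible stopgap is to reshape the vector into a single-column matrix inside the loss. This is a sketch, not a confirmed fix: forward_as_matrix is a hypothetical name, and it has not been verified against the exact versions listed in this thread (Enzyme v0.13.35, Flux v0.16.3).

```julia
using Flux, Enzyme

clf = Dense(3 => 2)

# Hypothetical workaround: reshape the Vector into a 3×1 Matrix so that
# logsoftmax receives a Matrix, the input shape reported above to work.
function forward_as_matrix(clf::Dense, x::AbstractVector{Float32})::Float32
    X = reshape(x, :, 1)             # 3×1 Matrix sharing x's memory
    log_y = Flux.logsoftmax(clf(X))  # normalises over dims=1, as for a Vector
    return first(log_y)
end

x = rand(Float32, 3)
# Two-argument Duplicated: the second argument is the zeroed shadow that receives gradients.
grads = Flux.gradient(forward_as_matrix, Duplicated(clf, Enzyme.make_zero(clf)), Const(x))
```

Because logsoftmax normalises along dims=1 by default, the 3×1 result holds the same values as the Vector version, so the gradient with respect to clf should be unchanged; grads has the same (weight, bias, σ) and nothing shape as the Matrix output above.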

(@v1.11) pkg> st Enzyme Flux
Status `~/.julia/environments/v1.11/Project.toml`
  [7da242da] Enzyme v0.13.35
  [587475ba] Flux v0.16.3

julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 4 default, 0 interactive, 2 GC (on 5 virtual cores)
Environment:
  JULIA_NUM_THREADS = 4

Also without using Flux.gradient:

julia> x = rand(Float32, 3, 1);

julia> Enzyme.gradient(Reverse, forward, clf, x)
(Dense(3 => 2), Float32[-1.2737786; -0.24676809; -1.2240989;;])

julia> x = rand(Float32, 3);

julia> Enzyme.gradient(Reverse, forward, clf, x)
ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Illegal replace ficticious phi for:   %_replacementA393 = phi {} addrspace(10)* , !dbg !388 of   %287 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* nonnull %21, i64 %286, i64 %14) #72, !dbg !395

Verbose output is very long:

julia> Enzyme.Compiler.VERBOSE_ERRORS[] = true
true

julia> Enzyme.gradient(Reverse, forward, clf, x)
ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope: 
; Function Attrs: mustprogress willreturn
define internal fastcc nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" {} addrspace(10)* @preprocess_julia__logsoftmax__202_44032({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" "enzymejl_parmtype"="4472460368" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" "enzymejl_parmtype"="4472460368" "enzymejl_parmtype_ref"="2" %1) unnamed_addr #64 !dbg !2808 {
top:
  %2 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !112
  %3 = bitcast i8* %2 to [1 x i64]*, !enzyme_caststack !0
  %pgcstack = call {}*** @julia.get_pgcstack() #68
  %current_task1457 = getelementptr inbounds {}**, {}*** %pgcstack, i64 -14
  %4 = bitcast {}*** %current_task1457 to {}*
  %ptls_field458 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %5 = bitcast {}*** %ptls_field458 to i64***
  %ptls_load459460 = load i64**, i64*** %5, align 8, !tbaa !43
  %6 = getelementptr inbounds i64*, i64** %ptls_load459460, i64 2
  %safepoint = load i64*, i64** %6, align 8, !tbaa !47
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #68, !dbg !2809
  fence syncscope("singlethread") seq_cst
  %7 = call noalias nonnull align 8 dereferenceable(8) "enzyme_inactive" "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" {} addrspace(10)* @ijl_box_int64(i64 signext 1) #69, !dbg !2810
  %8 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* nonnull @ijl_invoke, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5834769488 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4672712992 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4675544064 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4670938128 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5903779200 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %7) #70, !dbg !2810
  %9 = addrspacecast {} addrspace(10)* %8 to {} addrspace(11)*, !dbg !2817
  %10 = bitcast {} addrspace(10)* %8 to i8 addrspace(10)*, !dbg !2817
  %11 = addrspacecast i8 addrspace(10)* %10 to i8 addrspace(11)*, !dbg !2817
  %12 = getelementptr inbounds i8, i8 addrspace(11)* %11, i64 16, !dbg !2817
  %13 = bitcast i8 addrspace(11)* %12 to i64 addrspace(11)*, !dbg !2817
  %14 = load i64, i64 addrspace(11)* %13, align 8, !dbg !2817, !tbaa !50, !alias.scope !424, !noalias !454
  %.not = icmp eq i64 %14, 0, !dbg !2824
  br i1 %.not, label %L85, label %L36, !dbg !2818

L36:                                              ; preds = %top
  %15 = bitcast {} addrspace(10)* %8 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !2825
  %16 = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %15 to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !2825
  %17 = bitcast {} addrspace(10)* %8 to {} addrspace(10)** addrspace(10)*, !dbg !2825
  %18 = addrspacecast {} addrspace(10)** addrspace(10)* %17 to {} addrspace(10)** addrspace(11)*, !dbg !2825
  %19 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %18, align 8, !dbg !2825, !tbaa !134, !alias.scope !126, !noalias !127, !enzyme_nocache !0
  %20 = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %16, i64 0, i32 1, !dbg !2825
  %21 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %20, align 8, !dbg !2825, !tbaa !134, !alias.scope !126, !noalias !127, !dereferenceable_or_null !111, !align !112, !enzyme_type !121
  %22 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %21, {} addrspace(10)** %19) #68, !dbg !2825
  %23 = bitcast {} addrspace(10)* addrspace(13)* %22 to float addrspace(13)*, !dbg !2825
  %value_phi4584 = load float, float addrspace(13)* %23, align 4, !dbg !2826, !tbaa !496, !alias.scope !78, !noalias !79
  %24 = fsub float %value_phi4584, %value_phi4584, !dbg !2827
  %25 = fcmp uno float %24, 0.000000e+00, !dbg !2830
  br i1 %25, label %L399, label %L44.lr.ph, !dbg !2832

L44.lr.ph:                                        ; preds = %L36
  %26 = add i64 %14, 1, !dbg !2832
  br label %L44, !dbg !2832

L44:                                              ; preds = %L73, %L44.lr.ph
  %iv = phi i64 [ %iv.next, %L73 ], [ 0, %L44.lr.ph ]
  %27 = add i64 %iv, 2, !dbg !2833
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !2833
  %exitcond591.not = icmp eq i64 %27, %26, !dbg !2833
  br i1 %exitcond591.not, label %L85.loopexit, label %L73, !dbg !2834

...

invertguard_pass165:                              ; preds = %invertguard_exit166
  br label %invertguard_exit161

invertguard_exit166:                              ; preds = %invertL563, %invertL549
  %_unwrap854 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, {} addrspace(10)*, i1, i1*, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i64, {} addrspace(10)*, {} addrspace(10)**, i64, i64, i64, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, {} addrspace(10)**, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i1*, float*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1 } %tapeArg, 35, !dbg !4113
  %_unwrap791 = bitcast {} addrspace(10)* %_unwrap854 to {} addrspace(10)* addrspace(10)*
  %_unwrap792 = addrspacecast {} addrspace(10)* addrspace(10)* %_unwrap791 to {} addrspace(10)* addrspace(11)*
  %_unwrap793 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %_unwrap792, i64 2
  %_unwrap849 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, {} addrspace(10)*, i1, i1*, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i64, {} addrspace(10)*, {} addrspace(10)**, i64, i64, i64, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, {} addrspace(10)**, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i1*, float*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1 } %tapeArg, 81, !dbg !4309
  %_unwrap795 = addrspacecast {} addrspace(10)** %_unwrap849 to {} addrspace(10)* addrspace(11)*
  %.not536_unwrap = icmp eq {} addrspace(10)* addrspace(11)* %_unwrap793, %_unwrap795
  br i1 %.not536_unwrap, label %invertguard_exit161, label %invertguard_pass165
}

LLVM.CallInst(%287 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* nonnull %21, i64 %286, i64 %14) #72, !dbg !395)
LLVM.PHIInst(%_replacementA393 = phi {} addrspace(10)* , !dbg !401)


Stacktrace:
  [1] copy
    @ ./array.jl:350
  [2] unaliascopy
    @ ./abstractarray.jl:1516
  [3] unalias
    @ ./abstractarray.jl:1500
  [4] broadcast_unalias
    @ ./broadcast.jl:946
  [5] preprocess
    @ ./broadcast.jl:953
  [6] preprocess_args
    @ ./broadcast.jl:956
  [7] preprocess_args
    @ ./broadcast.jl:955
  [8] preprocess
    @ ./broadcast.jl:952
  [9] preprocess_args
    @ ./broadcast.jl:956
 [10] preprocess_args (repeats 2 times)
    @ ./broadcast.jl:955
 [11] preprocess
    @ ./broadcast.jl:952
 [12] override_bc_copyto!
    @ ~/.julia/packages/Enzyme/g1jMR/src/compiler/interpreter.jl:798
 [13] copyto!
    @ ./broadcast.jl:925
 [14] materialize!
    @ ./broadcast.jl:883
 [15] materialize!
    @ ./broadcast.jl:880
 [16] #logsoftmax!#202
    @ ~/.julia/packages/NNlib/CGMj3/src/softmax.jl:117

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/errors.jl:384
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/g1jMR/src/api.jl:269
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:1745
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:4655
  [6] codegen
    @ ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:3441 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5515
  [8] _thunk
    @ ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5515 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5567 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5678
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5863
 [12] autodiff
    @ ~/.julia/packages/Enzyme/g1jMR/src/Enzyme.jl:485 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/g1jMR/src/Enzyme.jl:524 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/g1jMR/src/sugar.jl:275 [inlined]
 [15] gradient(rm::ReverseMode{…}, f::typeof(forward), x::Dense{…}, args::Vector{…})
    @ Enzyme ~/.julia/packages/Enzyme/g1jMR/src/sugar.jl:262
 [16] top-level scope
    @ REPL[35]:1
Some type information was truncated. Use `show(err)` to see complete types.

mcabbott, Apr 16 '25 18:04

Runs fine on my machine. What versions of julia, Flux, and Enzyme are you on?

MasonProtter, Apr 16 '25 19:04

> Runs fine on my machine. What versions of julia, Flux, and Enzyme are you on?

julia> versioninfo()
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 96 × AMD EPYC 7571
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver1)
Threads: 1 default, 0 interactive, 1 GC (on 96 virtual cores)

Status `~/vscode/enzyme_test/Project.toml`
  [7da242da] Enzyme v0.13.35
  [587475ba] Flux v0.16.3

freddycct, Apr 16 '25 19:04

Btw, it happens when Flux.logsoftmax is used.
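If the failure really comes from Flux.logsoftmax, it should be reproducible without the Dense layer. A sketch of that narrowing step, assuming the same package versions as above; lse_first is a hypothetical helper name, and whether the Vector call fails on a given setup is not claimed here.

```julia
using Flux, Enzyme

# Hypothetical helper: the first entry of logsoftmax, with no Dense layer involved.
lse_first(v) = first(Flux.logsoftmax(v))

v = rand(Float32, 2)        # Vector input, the shape that failed in this thread
m = reshape(copy(v), :, 1)  # the same values as a 2×1 Matrix

# Enzyme.gradient in Reverse mode returns a tuple with one gradient per argument.
gm = only(Enzyme.gradient(Reverse, lse_first, m))  # Matrix case first
gv = only(Enzyme.gradient(Reverse, lse_first, v))  # Vector case
```

If the Matrix call succeeds and the Vector call hits the internal error, logsoftmax on a Vector is enough to trigger it. On a working setup both gradients equal the one-hot vector for index 1 minus softmax(v), so their entries sum to zero.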

freddycct, Apr 16 '25 19:04

I took this from the following example:

julia> using Flux, Enzyme

julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);  # from model zoo

julia> dup_model = Enzyme.Duplicated(model)  # this allocates space for the gradient
Duplicated(
  Chain(
    Dense(784 => 32, σ),                # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
    NNlib.softmax,
  ),
  # norm(∇) ≈ 0.0f0
)                   # Total: 4 arrays, 25_450 parameters, 199.391 KiB.

julia> x1 = randn32(28*28, 1);  # fake image

julia> y1 = [i==3 for i in 0:9];  # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))  # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857 …
    -0.0014538406], σ = nothing), nothing),), nothing, nothing)

freddycct, Apr 16 '25 20:04

This MWE confuses me. What is Flux.gradient and why does it take Enzyme data-structures?

vchuravy, May 06 '25 08:05

https://fluxml.ai/Flux.jl/stable/reference/training/enzyme/#Flux.gradient-Tuple%7BAny,%20Vararg%7BUnion%7BEnzymeCore.Const,%20EnzymeCore.Duplicated%7D%7D%7D

MasonProtter, May 06 '25 09:05

The code that Flux.gradient(f, ::Duplicated) ends up calling is here; it runs make_zero! and then autodiff:

https://github.com/FluxML/Flux.jl/blob/0e36af98f6fc5b7f3c95fe819a02172cfaaaf777/ext/FluxEnzymeExt/FluxEnzymeExt.jl#L44-L51
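The two steps described above (zero the shadow, then run reverse-mode autodiff) can also be written directly against Enzyme. This is a sketch under assumptions, not a copy of the linked FluxEnzymeExt code: loss is a hypothetical name, and the Matrix input is used because that is the shape reported to work in this thread.

```julia
using Flux, Enzyme

clf = Dense(3 => 2)
x = rand(Float32, 3, 1)                  # Matrix input, reported above to work
loss(m, x) = first(Flux.logsoftmax(m(x)))

dclf = Duplicated(clf, Enzyme.make_zero(clf))

# Step 1: zero the shadow so gradients from earlier calls do not accumulate
# (redundant on first use, needed when dclf is reused across training steps).
Enzyme.make_zero!(dclf.dval)

# Step 2: reverse-mode autodiff. Active seeds the scalar return with 1;
# Const(x) excludes the input from differentiation.
Enzyme.autodiff(Reverse, loss, Active, dclf, Const(x))

# The gradient is accumulated into the shadow, which has the same type as clf.
dW, db = dclf.dval.weight, dclf.dval.bias
```

Since y = W*x .+ b, the weight gradient should be the outer product of the bias gradient with x, which gives a quick sanity check on the result.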

mcabbott, May 06 '25 18:05

We recently released patches that fix similar error messages. Can you check whether it persists?

wsmoses, Sep 18 '25 13:09

On 1.10 and 1.11 this doesn't error for me. If it persists, please reopen!

wsmoses, Nov 09 '25 04:11