Enzyme.jl
Flux+Enzyme internal compiler error
using Flux
using Enzyme
clf = Dense(3 => 2)
function forward(clf::Dense, x::Vector{Float32})::Float32
    y = clf(x)
    log_y = Flux.logsoftmax(y)
    log_score = first(log_y)
    return log_score
end
x = rand(Float32, 3)
@show forward(clf, x)
grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x))
Enzyme compilation failed due to an internal error.
Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Stacktrace:
[1] copy
@ ./array.jl:350
[2] unaliascopy
@ ./abstractarray.jl:1516
[3] unalias
@ ./abstractarray.jl:1500
[4] broadcast_unalias
@ ./broadcast.jl:946
[5] preprocess
@ ./broadcast.jl:953
[6] preprocess_args
@ ./broadcast.jl:956
[7] preprocess_args
@ ./broadcast.jl:955
[8] preprocess
@ ./broadcast.jl:952
[9] preprocess_args
@ ./broadcast.jl:956
[10] preprocess_args (repeats 2 times)
@ ./broadcast.jl:955
[11] preprocess
@ ./broadcast.jl:952
[12] override_bc_copyto!
@ ~/.julia/packages/Enzyme/g1jMR/src/compiler/interpreter.jl:798
[13] copyto!
@ ./broadcast.jl:925
[14] materialize!
@ ./broadcast.jl:883
[15] materialize!
@ ./broadcast.jl:880
[16] #logsoftmax!#202
@ ~/.julia/packages/NNlib/CGMj3/src/softmax.jl:117
Curiously, the error goes away when the input is a matrix instead of a vector:
julia> x = rand(Float32, 3, 1); # now a Matrix
julia> function forward(clf::Dense, x)::Float32 # signature widened to allow this
           y = clf(x)
           log_y = Flux.logsoftmax(y)
           log_score = first(log_y)
           return log_score
       end
forward (generic function with 2 methods)
julia> forward(clf, x)
-1.4793367f0
julia> grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x))
((weight = Float32[0.37867802 0.37851515 0.19702016; -0.37867802 -0.37851515 -0.19702016], bias = Float32[0.7722113, -0.7722113], σ = nothing), nothing)
julia> x = rand(Float32, 3); # back to vector case
julia> forward(clf, x)
-1.7434618f0
julia> grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x))
ERROR: Enzyme compilation failed due to an internal error.
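Given the matrix-versus-vector behaviour above, one possible stopgap is to reshape the vector into a one-column matrix before differentiating. This is only a sketch built from the calls already shown in this report; `x_mat` is a name introduced here for illustration, and it assumes the reshaped input takes the same code path as the `rand(Float32, 3, 1)` case, which has not been verified on other versions.

```julia
using Flux
using Enzyme

clf = Dense(3 => 2)

# Same function as above, with the widened signature so it accepts a Matrix.
function forward(clf::Dense, x)::Float32
    y = clf(x)
    log_y = Flux.logsoftmax(y)
    return first(log_y)
end

x = rand(Float32, 3)        # the failing Vector input
x_mat = reshape(x, :, 1)    # 3×1 Matrix{Float32} sharing the same data (illustrative name)

# Assumption: this behaves like the Matrix case reported above.
grads = Flux.gradient(forward, Enzyme.Duplicated(clf), Const(x_mat))
```

This does not address the underlying compiler error; it only sidesteps the Vector code path in `logsoftmax`.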
(@v1.11) pkg> st Enzyme Flux
Status `~/.julia/environments/v1.11/Project.toml`
[7da242da] Enzyme v0.13.35
[587475ba] Flux v0.16.3
julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (arm64-apple-darwin24.0.0)
CPU: 11 × Apple M3 Pro
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 4 default, 0 interactive, 2 GC (on 5 virtual cores)
Environment:
JULIA_NUM_THREADS = 4
The error also occurs without Flux.gradient, when calling Enzyme.gradient directly:
julia> x = rand(Float32, 3, 1);
julia> Enzyme.gradient(Reverse, forward, clf, x)
(Dense(3 => 2), Float32[-1.2737786; -0.24676809; -1.2240989;;])
julia> x = rand(Float32, 3);
julia> Enzyme.gradient(Reverse, forward, clf, x)
ERROR: Enzyme compilation failed due to an internal error.
Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Illegal replace ficticious phi for: %_replacementA393 = phi {} addrspace(10)* , !dbg !388 of %287 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* nonnull %21, i64 %286, i64 %14) #72, !dbg !395
Verbose output is very long:
julia> Enzyme.Compiler.VERBOSE_ERRORS[] = true
true
julia> Enzyme.gradient(Reverse, forward, clf, x)
ERROR: Enzyme compilation failed due to an internal error.
Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" {} addrspace(10)* @preprocess_julia__logsoftmax__202_44032({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" "enzymejl_parmtype"="4472460368" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* noundef nonnull align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Pointer, [-1,8,0]:Integer, [-1,8,1]:Integer, [-1,8,2]:Integer, [-1,8,3]:Integer, [-1,8,4]:Integer, [-1,8,5]:Integer, [-1,8,6]:Integer, [-1,8,7]:Integer, [-1,8,8]:Pointer, [-1,8,8,-1]:Float@float, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}" "enzymejl_parmtype"="4472460368" "enzymejl_parmtype_ref"="2" %1) unnamed_addr #64 !dbg !2808 {
top:
%2 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !112
%3 = bitcast i8* %2 to [1 x i64]*, !enzyme_caststack !0
%pgcstack = call {}*** @julia.get_pgcstack() #68
%current_task1457 = getelementptr inbounds {}**, {}*** %pgcstack, i64 -14
%4 = bitcast {}*** %current_task1457 to {}*
%ptls_field458 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
%5 = bitcast {}*** %ptls_field458 to i64***
%ptls_load459460 = load i64**, i64*** %5, align 8, !tbaa !43
%6 = getelementptr inbounds i64*, i64** %ptls_load459460, i64 2
%safepoint = load i64*, i64** %6, align 8, !tbaa !47
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #68, !dbg !2809
fence syncscope("singlethread") seq_cst
%7 = call noalias nonnull align 8 dereferenceable(8) "enzyme_inactive" "enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" {} addrspace(10)* @ijl_box_int64(i64 signext 1) #69, !dbg !2810
%8 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)*, {} addrspace(10)*, {} addrspace(10)*, ...) @julia.call2({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)*)* nonnull @ijl_invoke, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5834769488 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4672712992 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4675544064 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 4670938128 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5903779200 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %7) #70, !dbg !2810
%9 = addrspacecast {} addrspace(10)* %8 to {} addrspace(11)*, !dbg !2817
%10 = bitcast {} addrspace(10)* %8 to i8 addrspace(10)*, !dbg !2817
%11 = addrspacecast i8 addrspace(10)* %10 to i8 addrspace(11)*, !dbg !2817
%12 = getelementptr inbounds i8, i8 addrspace(11)* %11, i64 16, !dbg !2817
%13 = bitcast i8 addrspace(11)* %12 to i64 addrspace(11)*, !dbg !2817
%14 = load i64, i64 addrspace(11)* %13, align 8, !dbg !2817, !tbaa !50, !alias.scope !424, !noalias !454
%.not = icmp eq i64 %14, 0, !dbg !2824
br i1 %.not, label %L85, label %L36, !dbg !2818
L36: ; preds = %top
%15 = bitcast {} addrspace(10)* %8 to { i8*, {} addrspace(10)* } addrspace(10)*, !dbg !2825
%16 = addrspacecast { i8*, {} addrspace(10)* } addrspace(10)* %15 to { i8*, {} addrspace(10)* } addrspace(11)*, !dbg !2825
%17 = bitcast {} addrspace(10)* %8 to {} addrspace(10)** addrspace(10)*, !dbg !2825
%18 = addrspacecast {} addrspace(10)** addrspace(10)* %17 to {} addrspace(10)** addrspace(11)*, !dbg !2825
%19 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %18, align 8, !dbg !2825, !tbaa !134, !alias.scope !126, !noalias !127, !enzyme_nocache !0
%20 = getelementptr inbounds { i8*, {} addrspace(10)* }, { i8*, {} addrspace(10)* } addrspace(11)* %16, i64 0, i32 1, !dbg !2825
%21 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %20, align 8, !dbg !2825, !tbaa !134, !alias.scope !126, !noalias !127, !dereferenceable_or_null !111, !align !112, !enzyme_type !121
%22 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %21, {} addrspace(10)** %19) #68, !dbg !2825
%23 = bitcast {} addrspace(10)* addrspace(13)* %22 to float addrspace(13)*, !dbg !2825
%value_phi4584 = load float, float addrspace(13)* %23, align 4, !dbg !2826, !tbaa !496, !alias.scope !78, !noalias !79
%24 = fsub float %value_phi4584, %value_phi4584, !dbg !2827
%25 = fcmp uno float %24, 0.000000e+00, !dbg !2830
br i1 %25, label %L399, label %L44.lr.ph, !dbg !2832
L44.lr.ph: ; preds = %L36
%26 = add i64 %14, 1, !dbg !2832
br label %L44, !dbg !2832
L44: ; preds = %L73, %L44.lr.ph
%iv = phi i64 [ %iv.next, %L73 ], [ 0, %L44.lr.ph ]
%27 = add i64 %iv, 2, !dbg !2833
%iv.next = add nuw nsw i64 %iv, 1, !dbg !2833
%exitcond591.not = icmp eq i64 %27, %26, !dbg !2833
br i1 %exitcond591.not, label %L85.loopexit, label %L73, !dbg !2834
...
invertguard_pass165: ; preds = %invertguard_exit166
br label %invertguard_exit161
invertguard_exit166: ; preds = %invertL563, %invertL549
%_unwrap854 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, {} addrspace(10)*, i1, i1*, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i64, {} addrspace(10)*, {} addrspace(10)**, i64, i64, i64, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, {} addrspace(10)**, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i1*, float*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1 } %tapeArg, 35, !dbg !4113
%_unwrap791 = bitcast {} addrspace(10)* %_unwrap854 to {} addrspace(10)* addrspace(10)*
%_unwrap792 = addrspacecast {} addrspace(10)* addrspace(10)* %_unwrap791 to {} addrspace(10)* addrspace(11)*
%_unwrap793 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %_unwrap792, i64 2
%_unwrap849 = extractvalue { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, i8*, i8*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i64, {} addrspace(10)*, i1, i1*, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i64, {} addrspace(10)*, {} addrspace(10)**, i64, i64, i64, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, {} addrspace(10)**, i64, i64, {} addrspace(10)*, {} addrspace(10)**, i1*, float*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1, {} addrspace(10)*, {} addrspace(10)**, {} addrspace(10)*, i1 } %tapeArg, 81, !dbg !4309
%_unwrap795 = addrspacecast {} addrspace(10)** %_unwrap849 to {} addrspace(10)* addrspace(11)*
%.not536_unwrap = icmp eq {} addrspace(10)* addrspace(11)* %_unwrap793, %_unwrap795
br i1 %.not536_unwrap, label %invertguard_exit161, label %invertguard_pass165
}
LLVM.CallInst(%287 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* nonnull %21, i64 %286, i64 %14) #72, !dbg !395)
LLVM.PHIInst(%_replacementA393 = phi {} addrspace(10)* , !dbg !401)
Stacktrace:
[1] copy
@ ./array.jl:350
[2] unaliascopy
@ ./abstractarray.jl:1516
[3] unalias
@ ./abstractarray.jl:1500
[4] broadcast_unalias
@ ./broadcast.jl:946
[5] preprocess
@ ./broadcast.jl:953
[6] preprocess_args
@ ./broadcast.jl:956
[7] preprocess_args
@ ./broadcast.jl:955
[8] preprocess
@ ./broadcast.jl:952
[9] preprocess_args
@ ./broadcast.jl:956
[10] preprocess_args (repeats 2 times)
@ ./broadcast.jl:955
[11] preprocess
@ ./broadcast.jl:952
[12] override_bc_copyto!
@ ~/.julia/packages/Enzyme/g1jMR/src/compiler/interpreter.jl:798
[13] copyto!
@ ./broadcast.jl:925
[14] materialize!
@ ./broadcast.jl:883
[15] materialize!
@ ./broadcast.jl:880
[16] #logsoftmax!#202
@ ~/.julia/packages/NNlib/CGMj3/src/softmax.jl:117
Stacktrace:
[1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/errors.jl:384
[2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/errors.jl:210
[3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/g1jMR/src/api.jl:269
[4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:1745
[5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:4655
[6] codegen
@ ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:3441 [inlined]
[7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5515
[8] _thunk
@ ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5515 [inlined]
[9] cached_compilation
@ ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5567 [inlined]
[10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5678
[11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/g1jMR/src/compiler.jl:5863
[12] autodiff
@ ~/.julia/packages/Enzyme/g1jMR/src/Enzyme.jl:485 [inlined]
[13] autodiff
@ ~/.julia/packages/Enzyme/g1jMR/src/Enzyme.jl:524 [inlined]
[14] macro expansion
@ ~/.julia/packages/Enzyme/g1jMR/src/sugar.jl:275 [inlined]
[15] gradient(rm::ReverseMode{…}, f::typeof(forward), x::Dense{…}, args::Vector{…})
@ Enzyme ~/.julia/packages/Enzyme/g1jMR/src/sugar.jl:262
[16] top-level scope
@ REPL[35]:1
Some type information was truncated. Use `show(err)` to see complete types.
Runs fine on my machine. What versions of Julia, Flux, and Enzyme are you on?
julia> versioninfo()
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 96 × AMD EPYC 7571
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver1)
Threads: 1 default, 0 interactive, 1 GC (on 96 virtual cores)
Status `~/vscode/enzyme_test/Project.toml`
[7da242da] Enzyme v0.13.35
[587475ba] Flux v0.16.3
By the way, it happens when Flux.logsoftmax is used.
I took it from this example:
julia> using Flux, Enzyme
julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);
# from model zoo
julia> dup_model = Enzyme.Duplicated(model) # this allocates space for the gradient
Duplicated(
Chain(
Dense(784 => 32, σ), # 25_120 parameters
Dense(32 => 10), # 330 parameters
NNlib.softmax,
),
# norm(∇) ≈ 0.0f0
) # Total: 4 arrays, 25_450 parameters, 199.391 KiB.
julia> x1 = randn32(28*28, 1); # fake image
julia> y1 = [i==3 for i in 0:9]; # fake label
julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model,
Const(x1), Const(y1)) # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857 …
-0.0014538406], σ = nothing), nothing),), nothing, nothing)
This MWE confuses me. What is Flux.gradient and why does it take Enzyme data-structures?
https://fluxml.ai/Flux.jl/stable/reference/training/enzyme/#Flux.gradient-Tuple%7BAny,%20Vararg%7BUnion%7BEnzymeCore.Const,%20EnzymeCore.Duplicated%7D%7D%7D
The code that Flux.gradient(f, ::Duplicated) ends up calling is here; it calls make_zero! and then autodiff:
https://github.com/FluxML/Flux.jl/blob/0e36af98f6fc5b7f3c95fe819a02172cfaaaf777/ext/FluxEnzymeExt/FluxEnzymeExt.jl#L44-L51
We recently released patches that fix similar error messages. Can you check whether the error persists?
On 1.10 and 1.11 this doesn't error for me. If it persists, please reopen!