Enzyme.jl
Recursive store of symbol
As discussed on Slack, this:
using Flux, Enzyme
m = Chain(MultiHeadAttention(5; nheads=1), first)
x = Flux.onehotbatch([1; 2; 3; 2;;], 1:5)
Flux.logitcrossentropy(m(x), x) isa Float32
Enzyme.gradient(ReverseWithPrimal, m -> Flux.logitcrossentropy(m(x), x), m)
gives the following warnings + error:
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/1cGqD/src/utils.jl:52
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/1cGqD/src/utils.jl:52
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/1cGqD/src/utils.jl:52
ERROR: Error handling recursive stores for Symbol which has a fieldcount of 0
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] create_recursive_stores(B::LLVM.IRBuilder, Ty::DataType, prev::LLVM.Value)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:622
[3] create_recursive_stores(B::LLVM.IRBuilder, Ty::DataType, prev::LLVM.Value)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:667
[4] shadow_alloc_rewrite(V::Ptr{…}, gutils::Ptr{…}, Orig::Ptr{…}, idx::UInt64, prev::Ptr{…}, used::UInt8)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:739
[5] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, runtimeActivity::Bool, width::Int64, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/3VNOP/src/api.jl:411
[6] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:1663
[7] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:4664
[8] codegen
@ ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:3450 [inlined]
[9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:5528
[10] _thunk
@ ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:5528 [inlined]
[11] cached_compilation
@ ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:5580 [inlined]
[12] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:5691
[13] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:5876
[14] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::Chain{…}, df::Chain{…}, primal_1::OneHotArrays.OneHotArray{…}, shadow_1_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/3VNOP/src/rules/jitrules.jl:445
[15] #5
@ ./REPL[5]:1 [inlined]
[16] augmented_julia__5_13696_inner_1wrap
@ ./REPL[5]:0
[17] macro expansion
@ ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:5458 [inlined]
[18] enzyme_call
@ ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:4992 [inlined]
[19] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/3VNOP/src/compiler.jl:4928 [inlined]
[20] autodiff
@ ~/.julia/packages/Enzyme/3VNOP/src/Enzyme.jl:396 [inlined]
[21] autodiff
@ ~/.julia/packages/Enzyme/3VNOP/src/Enzyme.jl:524 [inlined]
[22] macro expansion
@ ~/.julia/packages/Enzyme/3VNOP/src/sugar.jl:324 [inlined]
[23] gradient(::ReverseMode{false, false, FFIABI, false, false}, ::var"#5#6", ::Chain{Tuple{…}})
@ Enzyme ~/.julia/packages/Enzyme/3VNOP/src/sugar.jl:262
[24] top-level scope
@ REPL[5]:1
Some type information was truncated. Use `show(err)` to see complete types.
(@v1.11) pkg> st Enzyme Flux
Status `~/.julia/environments/v1.11/Project.toml`
[7da242da] Enzyme v0.13.41
[587475ba] Flux v0.16.3
julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (arm64-apple-darwin24.0.0)
CPU: 11 × Apple M3 Pro
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 4 default, 0 interactive, 2 GC (on 5 virtual cores)
Environment:
JULIA_NUM_THREADS = 4
As an aside, Zygote fails here because of https://github.com/FluxML/Zygote.jl/issues/1567, which can be worked around via type piracy if we want an answer to compare against:
julia> Flux.gradient(m -> Flux.logitcrossentropy(m(x), x), m)
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
The function `reshape` exists, but no method is defined for this combination of argument types.
julia> Base.reshape(::Nothing, ::Tuple{Int, Int, Int}) = nothing
julia> Flux.gradient(m -> Flux.logitcrossentropy(m(x), x), m)
((layers = ((nheads = nothing, q_proj = (weight = Float32[0.0025089695 0.031038722 … 0.0
I am sure this could be boiled down further, but I'm not certain I have time. I believe the core is likely to be NNlib.batched_mul, which has an @threads loop over mul!. Adding a rule for that function does seem to make this MWE work. (I have a branch somewhere, insufficiently tested. With it, my real problem hits a different error.)
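A possible starting point for a smaller reproducer, assuming the threaded loop really is the trigger: the sketch below mimics the pattern described above (an @threads loop calling mul! on each slice of a 3-d array). The name `threaded_batched_mul!` is hypothetical and this is not NNlib's actual implementation. Whether differentiating it with Enzyme reproduces the same error is untested.

```julia
using LinearAlgebra

# Hypothetical stand-in for the suspected pattern: a threaded loop over the
# batch dimension, multiplying one pair of matrix slices per iteration.
# This is NOT NNlib's code, only an illustration of the structure.
function threaded_batched_mul!(C::Array{T,3}, A::Array{T,3}, B::Array{T,3}) where {T}
    Threads.@threads for k in axes(C, 3)
        # @views avoids copying the slices, so mul! writes directly into C
        @views mul!(C[:, :, k], A[:, :, k], B[:, :, k])
    end
    return C
end

A = rand(Float32, 3, 4, 2)
B = rand(Float32, 4, 5, 2)
C = threaded_batched_mul!(zeros(Float32, 3, 5, 2), A, B)
```

Since @threads runs the loop body in tasks, a function shaped like this would at least be consistent with the jl_new_task warnings printed before the error.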
Not really a complaint, but I confess I find these errors extremely opaque. To a first approximation, everything fails all the time for me. I would love to help make it work, but it's very slow to iteratively isolate things. (I tried Enzyme on a real problem yesterday, and today timed reloading & updating: 16 minutes of compilation time to get to the above message.)