Enzyme.jl
Enzyme assertion failure with ComponentArrays
using Lux, ComponentArrays, Random, Enzyme
rng = Random.MersenneTwister(1234)
# Define a basic neural network: 5 inputs -> 5 hidden units (tanh) -> 1 output
NN = Lux.Chain(Lux.Dense(5 => 5, tanh), Lux.Dense(5 => 1))
# Initialize parameters `ps` and state `st` from the seeded RNG
ps, st = Lux.setup(rng, NN)
# Input vector used to test the initialized network
xtest = [0.1, 0.2, 0.3, 0.4, 0.5]
# Shadow for `xtest`: Enzyme writes the gradient w.r.t. `xtest` into it.
# `zero(xtest)` matches the shape and element type of `xtest`.
dx = zero(xtest)
# Turn on Enzyme's runtime-activity mode before differentiating
Enzyme.API.runtimeActivity!(true)
# Run the model `NN` on input `x` with parameters `ps` and state `st`, then reduce
# the model output to a scalar by summing it. The returned state is discarded.
function test_function(NN, x, ps, st)
    prediction, _ = NN(x, ps, st)
    total = sum(prediction)
    return total
end
# Reverse-mode gradient w.r.t. `xtest` with NamedTuple parameters (works without CA)
@time autodiff(Reverse, test_function, Active,
    Const(NN), Duplicated(xtest, dx), Const(ps), Const(st))
# Same call with the parameters wrapped in a ComponentArray
ps_ca = ComponentArray(ps)
# Enzyme accumulates into the shadow `dx` rather than overwriting it, so clear the
# gradient left by the first call; otherwise the two results would be summed.
fill!(dx, 0)
@time autodiff(Reverse, test_function, Active,
    Const(NN), Duplicated(xtest, dx), Const(ps_ca), Const(st))
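If the parameters are a `NamedTuple`, it works as expected; changing them to a `ComponentArray` makes Julia abort with a C++ assertion failure inside Enzyme (`GradientUtils::unwrapM`, see the stack trace below).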
Stacktrace
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1624: virtual llvm::Value* GradientUtils::unwrapM(llvm::Value*, llvm::IRBuilder<>&, const ValueToValueMapTy&, UnwrapMode, llvm::BasicBlock*, bool): Assertion `unwrapMode != UnwrapMode::LegalFullUnwrap' failed.
[765106] signal (6.-6): Aborted
in expression starting at REPL[14]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x70b987c267da)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1624
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1150
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1144
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1327
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1088
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1144
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1066
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1066
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1173
freeCache at /workspace/srcdir/Enzyme/enzyme/Enzyme/DiffeGradientUtils.cpp:754
createCacheForScope at /workspace/srcdir/Enzyme/enzyme/Enzyme/CacheUtility.cpp:1004
ensureLookupCached at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:2366 [inlined]
ensureLookupCached at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:2352
lookupM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:7314
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:930
lookupM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:6537
lookup at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:2199
visitBinaryOperator at /workspace/srcdir/Enzyme/build/Enzyme/BinopDerivatives.inc:613
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:4404
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:615
EnzymeCreatePrimalAndGradient at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/api.jl:154
unknown function (ip: 0x70b8e4dffacb)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
enzyme! at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:3166
unknown function (ip: 0x70b8e4dbcc98)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
#codegen#509 at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5070
codegen at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:4477 [inlined]
_thunk at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755
_thunk at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755 [inlined]
cached_compilation at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5793 [inlined]
#554 at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5859
#JuliaContext#149 at /home/avikpal/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
unknown function (ip: 0x70b8e4da7466)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
JuliaContext at /home/avikpal/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
#s2027#553 at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5811 [inlined]
#s2027#553 at ./none:0
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
GeneratedFunctionStub at ./boot.jl:602
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_call_staged at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/method.c:540
ijl_code_for_staged at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/method.c:593
get_staged at ./compiler/utilities.jl:123
retrieve_code_info at ./compiler/utilities.jl:135 [inlined]
InferenceState at ./compiler/inferencestate.jl:430
typeinf_ext at ./compiler/typeinfer.jl:1049
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1082
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1078
jfptr_typeinf_ext_toplevel_35682.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_type_infer at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:394
jl_generate_fptr_impl at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/jitlayers.cpp:504
jl_compile_method_internal at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2481 [inlined]
jl_compile_method_internal at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2368
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2887 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
runtime_generic_augfwd at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/rules/jitrules.jl:175
unknown function (ip: 0x70b8e4996129)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
generic_matmatmul! at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:783
mul! at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:263 [inlined]
mul! at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
__matmul! at /mnt/research/ongoing/lux/LuxLib.jl/src/impl/fused_dense.jl:5 [inlined]
__fused_dense_bias_activation_impl at /mnt/research/ongoing/lux/LuxLib.jl/src/impl/fused_dense.jl:30 [inlined]
fused_dense_bias_activation at /mnt/research/ongoing/lux/LuxLib.jl/src/api/dense.jl:46 [inlined]
fused_dense_bias_activation at /mnt/research/ongoing/lux/LuxLib.jl/src/api/dense.jl:38 [inlined]
Dense at /mnt/research/ongoing/lux/Lux.jl/src/layers/basic.jl:218 [inlined]
Dense at /mnt/research/ongoing/lux/Lux.jl/src/layers/basic.jl:214 [inlined]
apply at /home/avikpal/.julia/packages/LuxCore/qiHPC/src/LuxCore.jl:179
macro expansion at ./abstractarray.jl:0 [inlined]
applychain at /mnt/research/ongoing/lux/Lux.jl/src/layers/containers.jl:478 [inlined]
Chain at /mnt/research/ongoing/lux/Lux.jl/src/layers/containers.jl:476 [inlined]
test_function at ./REPL[10]:2 [inlined]
test_function at ./REPL[10]:0 [inlined]
diffejulia_test_function_7588_inner_1wrap at ./REPL[10]:0
macro expansion at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5703 [inlined]
enzyme_call at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5381 [inlined]
CombinedAdjointThunk at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5260 [inlined]
autodiff at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:291 [inlined]
autodiff at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:303
unknown function (ip: 0x70b8e4994ae6)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
top-level scope at ./timing.jl:279
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:925
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:877
eval_body at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/interpreter.c:579
eval_body at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/interpreter.c:544
jl_interpret_toplevel_thunk at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
#run_repl#59 at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91734.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
#1013 at ./client.jl:432
jfptr_YY.1013_82700.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_main_repl at ./client.jl:416
exec_options at ./client.jl:333
_start at ./client.jl:552
jfptr__start_82726.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x70b987c2814f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 103393830 (Pool: 103275172; Big: 118658); GC: 72
[1] 765106 IOT instruction (core dumped) julia --depwarn=yes --threads=auto --project=.
Can you minimize the error by chance?
@avik-pal would you be able to reduce this? If so I can try to fix this later today.
I tried minimizing it, but the furthest I got is https://github.com/EnzymeAD/Enzyme.jl/issues/1451#issuecomment-2116218381, which hits a different error.
@avik-pal with the linked comment resolved, any chance you can see if this is reducible?
bumping here @avik-pal
bump @avik-pal
bump @avik-pal
Strangely enough, if `xtest` is a `Matrix` instead of a `Vector`, it works:
using Lux
using LinearAlgebra, ComponentArrays, Random, Enzyme
rng = Random.MersenneTwister(1234)
# A single Dense layer: 5 inputs -> 5 outputs with tanh activation
NN = Lux.Dense(5 => 5, tanh)
# Initialize parameters `ps` and state `st` from the seeded RNG
ps, st = Lux.setup(rng, NN)
# A 5x1 Matrix input (not a Vector): with this shape the ComponentArray case works.
# Draw it from `rng` as well so the whole script is reproducible.
xtest = rand(rng, Float32, 5, 1)
# Shadow for `xtest` with the same shape and element type
dx = zero(xtest)
# Turn on Enzyme's runtime-activity mode before differentiating
Enzyme.API.runtimeActivity!(true)
# Run the model `NN` on input `x` with parameters `ps` and state `st`, then reduce
# the model output to a scalar by summing it. The returned state is discarded.
function test_function(NN, x, ps, st)
    prediction, _ = NN(x, ps, st)
    total = sum(prediction)
    return total
end
# Reverse-mode gradient w.r.t. the Matrix input, with ComponentArray parameters
ps_ca = ComponentArray(ps)
@time autodiff(
    Reverse, test_function, Active,
    Const(NN), Duplicated(xtest, dx), Const(ps_ca), Const(st),
)
Here is a reduced version using just LinearAlgebra and ComponentArrays (no Lux):
using LinearAlgebra, ComponentArrays, Random, Enzyme
rng = Random.MersenneTwister(1234)
# Hand-built Dense-layer parameters in Float32, drawn from `rng` for reproducibility
ps = (; weight=rand(rng, Float32, 5, 5), bias=rand(rng, Float32, 5))
# Float64 input: deliberately a different precision from the Float32 parameters
# (the assertion is reported not to trigger when `xtest` is Float32)
xtest = rand(rng, Float64, 5)
# Shadow for `xtest` with the same shape and element type
dx = zero(xtest)
# Turn on Enzyme's runtime-activity mode before differentiating
Enzyme.API.runtimeActivity!(true)
# Apply the affine map `ps.weight * X + ps.bias` to `x`, with `x` viewed as a
# one-column matrix `X`, and return the sum of all entries of the result.
function test_function(x, ps)
    column = reshape(x, :, 1)
    W, b = ps.weight, ps.bias
    return sum(muladd(W, column, b))
end
ps_ca = ComponentArray(ps)
# Sanity-check the primal with both parameter containers. Running it on `ps_ca`
# as well shows the forward pass itself works with the ComponentArray that is
# passed to `autodiff` below.
@time test_function(xtest, ps)
@time test_function(xtest, ps_ca)
# Reverse-mode gradient w.r.t. the Float64 `xtest`; this is the call that hits the assertion
@time autodiff(Reverse, test_function, Active, Duplicated(xtest, dx), Const(ps_ca))
It might have something to do with mixed precision: I don't hit the assertion if `xtest` is `Float32`.