
Enzyme Assertion with ComponentArrays

Open avik-pal opened this issue 9 months ago • 9 comments

using Lux, ComponentArrays, Random, Enzyme

rng = Random.MersenneTwister(1234)

# Define a basic neural network structure
NN = Lux.Chain(Lux.Dense(5 => 5, tanh), Lux.Dense(5 => 1))

# Setup the network
ps, st = Lux.setup(rng, NN)

# Test the initialized network with some input values
xtest = [0.1, 0.2, 0.3, 0.4, 0.5]
dx = zero(xtest)  # shadow buffer for the gradient with respect to xtest

Enzyme.API.runtimeActivity!(true)

function test_function(NN, x, ps, st)
    y, _ = NN(x, ps, st)
    return sum(y)
end

@time autodiff(
    Reverse, test_function, Active, Const(NN), Duplicated(xtest, dx), Const(ps), Const(st))  # works with NamedTuple parameters (no ComponentArray)

ps_ca = ComponentArray(ps)

@time autodiff(
    Reverse, test_function, Active, Const(NN), Duplicated(xtest, dx), Const(ps_ca), Const(st))

If the parameters are a NamedTuple, it works as expected; changing them to a ComponentArray triggers an assertion failure.

Stacktrace
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1624: virtual llvm::Value* GradientUtils::unwrapM(llvm::Value*, llvm::IRBuilder<>&, const ValueToValueMapTy&, UnwrapMode, llvm::BasicBlock*, bool): Assertion `unwrapMode != UnwrapMode::LegalFullUnwrap' failed.

[765106] signal (6.-6): Aborted
in expression starting at REPL[14]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x70b987c267da)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1624
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1150
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1144
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1327
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1088
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1144
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1066
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1066
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1063
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:1173
freeCache at /workspace/srcdir/Enzyme/enzyme/Enzyme/DiffeGradientUtils.cpp:754
createCacheForScope at /workspace/srcdir/Enzyme/enzyme/Enzyme/CacheUtility.cpp:1004
ensureLookupCached at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:2366 [inlined]
ensureLookupCached at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:2352
lookupM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:7314
unwrapM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:930
lookupM at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:6537
lookup at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:2199
visitBinaryOperator at /workspace/srcdir/Enzyme/build/Enzyme/BinopDerivatives.inc:613
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:4404
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:615
EnzymeCreatePrimalAndGradient at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/api.jl:154
unknown function (ip: 0x70b8e4dffacb)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
enzyme! at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:3166
unknown function (ip: 0x70b8e4dbcc98)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
#codegen#509 at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5070
codegen at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:4477 [inlined]
_thunk at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755
_thunk at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755 [inlined]
cached_compilation at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5793 [inlined]
#554 at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5859
#JuliaContext#149 at /home/avikpal/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
unknown function (ip: 0x70b8e4da7466)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
JuliaContext at /home/avikpal/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
#s2027#553 at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5811 [inlined]
#s2027#553 at ./none:0
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
GeneratedFunctionStub at ./boot.jl:602
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_call_staged at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/method.c:540
ijl_code_for_staged at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/method.c:593
get_staged at ./compiler/utilities.jl:123
retrieve_code_info at ./compiler/utilities.jl:135 [inlined]
InferenceState at ./compiler/inferencestate.jl:430
typeinf_ext at ./compiler/typeinfer.jl:1049
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1082
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1078
jfptr_typeinf_ext_toplevel_35682.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_type_infer at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:394
jl_generate_fptr_impl at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/jitlayers.cpp:504
jl_compile_method_internal at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2481 [inlined]
jl_compile_method_internal at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2368
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2887 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
runtime_generic_augfwd at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/rules/jitrules.jl:175
unknown function (ip: 0x70b8e4996129)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
generic_matmatmul! at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:783
mul! at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:263 [inlined]
mul! at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
__matmul! at /mnt/research/ongoing/lux/LuxLib.jl/src/impl/fused_dense.jl:5 [inlined]
__fused_dense_bias_activation_impl at /mnt/research/ongoing/lux/LuxLib.jl/src/impl/fused_dense.jl:30 [inlined]
fused_dense_bias_activation at /mnt/research/ongoing/lux/LuxLib.jl/src/api/dense.jl:46 [inlined]
fused_dense_bias_activation at /mnt/research/ongoing/lux/LuxLib.jl/src/api/dense.jl:38 [inlined]
Dense at /mnt/research/ongoing/lux/Lux.jl/src/layers/basic.jl:218 [inlined]
Dense at /mnt/research/ongoing/lux/Lux.jl/src/layers/basic.jl:214 [inlined]
apply at /home/avikpal/.julia/packages/LuxCore/qiHPC/src/LuxCore.jl:179
macro expansion at ./abstractarray.jl:0 [inlined]
applychain at /mnt/research/ongoing/lux/Lux.jl/src/layers/containers.jl:478 [inlined]
Chain at /mnt/research/ongoing/lux/Lux.jl/src/layers/containers.jl:476 [inlined]
test_function at ./REPL[10]:2 [inlined]
test_function at ./REPL[10]:0 [inlined]
diffejulia_test_function_7588_inner_1wrap at ./REPL[10]:0
macro expansion at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5703 [inlined]
enzyme_call at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5381 [inlined]
CombinedAdjointThunk at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5260 [inlined]
autodiff at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:291 [inlined]
autodiff at /home/avikpal/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:303
unknown function (ip: 0x70b8e4994ae6)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
top-level scope at ./timing.jl:279
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:925
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:877
eval_body at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/interpreter.c:579
eval_body at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/interpreter.c:544
jl_interpret_toplevel_thunk at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
#run_repl#59 at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91734.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
#1013 at ./client.jl:432
jfptr_YY.1013_82700.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_main_repl at ./client.jl:416
exec_options at ./client.jl:333
_start at ./client.jl:552
jfptr__start_82726.1 at /home/avikpal/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci4-2/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x70b987c2814f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 103393830 (Pool: 103275172; Big: 118658); GC: 72
[1]    765106 IOT instruction (core dumped)  julia --depwarn=yes --threads=auto --project=.

avik-pal, May 14 '24 20:05

Can you minimize the error by chance?

wsmoses, May 14 '24 22:05

@avik-pal would you be able to reduce this? If so I can try to fix this later today.

wsmoses, May 16 '24 17:05

I tried minimizing, but https://github.com/EnzymeAD/Enzyme.jl/issues/1451#issuecomment-2116218381 is as far as I could get, and that is a different error.

avik-pal, May 16 '24 21:05

@avik-pal with the linked comment resolved, any chance you can see if this is reducible?

wsmoses, May 24 '24 04:05

bumping here @avik-pal

wsmoses, Jun 06 '24 20:06

bump @avik-pal

wsmoses, Jun 15 '24 19:06

bump @avik-pal

wsmoses, Jul 14 '24 17:07

Strangely enough, if xtest is a Matrix instead of a Vector, it works:

using Lux
using LinearAlgebra, ComponentArrays, Random, Enzyme

rng = Random.MersenneTwister(1234)

# Define a basic neural network structure
NN = Lux.Dense(5 => 5, tanh)

# Setup the network
ps, st = Lux.setup(rng, NN)

# Test the initialized network with some input values
xtest = rand(Float32, 5, 1)  # a Matrix rather than a Vector: with this input it works
dx = zero.(xtest)

Enzyme.API.runtimeActivity!(true)

function test_function(NN, x, ps, st)
    y, _ = NN(x, ps, st)
    return sum(y)
end

ps_ca = ComponentArray(ps)

@time autodiff(Reverse, test_function, Active, Const(NN),
    Duplicated(xtest, dx), Const(ps_ca), Const(st))

avik-pal, Jul 14 '24 19:07

Here is a reduced version that drops Lux and uses just LinearAlgebra and ComponentArrays:

using LinearAlgebra, ComponentArrays, Random, Enzyme

rng = Random.MersenneTwister(1234)

ps = (; weight=rand(Float32, 5, 5), bias=rand(Float32, 5))

xtest = rand(Float64, 5)
dx = zero.(xtest)

Enzyme.API.runtimeActivity!(true)

function test_function(x, ps)
    x_ = reshape(x, :, 1)
    y = muladd(ps.weight, x_, ps.bias)
    return sum(y)
end

ps_ca = ComponentArray(ps)

@time test_function(xtest, ps)
@time autodiff(Reverse, test_function, Active, Duplicated(xtest, dx), Const(ps_ca))

It might have something to do with mixed precision: I don't get any assertion failure if xtest is Float32, matching the Float32 parameters.
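
The element-type mismatch itself can be inspected without Enzyme. The following is a minimal sketch, assuming plain arrays in place of the ComponentArray; the names weight, bias, x64, x32, and f are illustrative and not taken from the issue. It only shows that the Float64 input promotes the whole computation to Float64, while a Float32 input keeps everything in Float32. It does not reproduce the assertion.

```julia
using LinearAlgebra

# Hypothetical stand-ins for ps.weight and ps.bias from the reduced example.
weight = rand(Float32, 5, 5)
bias = rand(Float32, 5)

x64 = rand(Float64, 5)   # mixed precision: Float64 input, Float32 parameters
x32 = Float32.(x64)      # the same input converted to the parameter eltype

# Same computation as test_function in the reduced example, on plain arrays.
f(x, w, b) = sum(muladd(w, reshape(x, :, 1), b))

y_mixed = f(x64, weight, bias)   # Float32 * Float64 promotes the result to Float64
y_same = f(x32, weight, bias)    # every operand is Float32

println(typeof(y_mixed), " vs ", typeof(y_same))
```

If the assertion really is tied to the mixed-precision path, converting the input with Float32.(xtest) before calling autodiff would correspond to the Float32 case that did not assert.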

avik-pal, Jul 14 '24 19:07