Tracking support for Enzyme.jl

Open avik-pal opened this issue 2 years ago • 14 comments

Opening this issue mostly to track how much of Lux (v0.4.7) is supported by Enzyme (v0.10.4)

  • Lux.Dense is supported
using Lux, Random, Enzyme
rng = Random.default_rng()

function loss_function(model, x, ps, st)
    return sum(Lux.apply(model, x, ps, st)[1])
end

model = Chain(Dense(2 => 4), Dense(4 => 2))
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 1)

dps = Lux.fmap(zero, ps)
Enzyme.autodiff(loss_function, Const(model), Const(x), Duplicated(ps, dps), Const(st))
println(dps)
  • Lux.BatchNorm segfaults
using Lux, Random, Enzyme
rng = Random.default_rng()

function loss_function(model, x, ps, st)
    return sum(Lux.apply(model, x, ps, st)[1])
end

model = Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 1)

dps = Lux.fmap(zero, ps)
Enzyme.autodiff(loss_function, Const(model), Const(x), Duplicated(ps, dps), Const(st))
println(dps)
warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-pc-linux-gnu'

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-pc-linux-gnu'

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-pc-linux-gnu'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %109 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %2) #33, !dbg !321"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %255 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %3) #33, !dbg !504"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %478 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %4) #33, !dbg !700"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %553 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %5) #33, !dbg !814"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %109 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %2) #34, !dbg !323"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %255 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %3) #34, !dbg !506"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %478 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %4) #34, !dbg !702"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %553 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %5) #34, !dbg !816"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35

signal (11): Segmentation fault
in expression starting at REPL[33]:1
_ZNK4llvm10AllocaInst14isStaticAllocaEv at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
runOnFunction at /buildworker/worker/package_linux64/build/src/llvm-late-gc-lowering.cpp:2689
_ZN4llvm13FPPassManager13runOnFunctionERNS_8FunctionE at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm13FPPassManager11runOnModuleERNS_6ModuleE at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm6legacy15PassManagerImpl3runERNS_6ModuleE at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
LLVMRunPassManager at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
LLVMRunPassManager at /mnt/julia/packages/LLVM/WjSQG/lib/13/libLLVM_h.jl:4898 [inlined]
run! at /mnt/julia/packages/LLVM/WjSQG/src/passmanager.jl:39 [inlined]
#55 at /mnt/julia/packages/Enzyme/di3zM/src/compiler/optimize.jl:230
#ModulePassManager#64 at /mnt/julia/packages/LLVM/WjSQG/src/passmanager.jl:33
unknown function (ip: 0x7f4adebb715e)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
ModulePassManager at /mnt/julia/packages/LLVM/WjSQG/src/passmanager.jl:31
post_optimze! at /mnt/julia/packages/Enzyme/di3zM/src/compiler/optimize.jl:227 [inlined]
post_optimze! at /mnt/julia/packages/Enzyme/di3zM/src/compiler/optimize.jl:221 [inlined]
_thunk at /mnt/julia/packages/Enzyme/di3zM/src/compiler.jl:4617
unknown function (ip: 0x7f4ade76325d)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
cached_compilation at /mnt/julia/packages/Enzyme/di3zM/src/compiler.jl:4637
unknown function (ip: 0x7f4af61f3885)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
#s565#115 at /mnt/julia/packages/Enzyme/di3zM/src/compiler.jl:4697 [inlined]
#s565#115 at ./none:0
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
GeneratedFunctionStub at ./boot.jl:582
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
jl_call_staged at /buildworker/worker/package_linux64/build/src/method.c:520
ijl_code_for_staged at /buildworker/worker/package_linux64/build/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:280
typeinf_edge at ./compiler/typeinfer.jl:867
abstract_call_method at ./compiler/abstractinterpretation.jl:632
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:156
abstract_call_known at ./compiler/abstractinterpretation.jl:1666
abstract_call at ./compiler/abstractinterpretation.jl:1724
abstract_call at ./compiler/abstractinterpretation.jl:1703
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1845
typeinf_local at ./compiler/abstractinterpretation.jl:2310
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2406
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:876
abstract_call_method at ./compiler/abstractinterpretation.jl:632
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:156
abstract_call_known at ./compiler/abstractinterpretation.jl:1666
abstract_call at ./compiler/abstractinterpretation.jl:1724
abstract_call at ./compiler/abstractinterpretation.jl:1703
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1845
typeinf_local at ./compiler/abstractinterpretation.jl:2310
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2406
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:876
abstract_call_method at ./compiler/abstractinterpretation.jl:632
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:156
abstract_call_known at ./compiler/abstractinterpretation.jl:1666
abstract_call at ./compiler/abstractinterpretation.jl:1724
abstract_call at ./compiler/abstractinterpretation.jl:1703
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1845
typeinf_local at ./compiler/abstractinterpretation.jl:2310
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2406
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:957
typeinf_ext_toplevel at ./compiler/typeinfer.jl:990
typeinf_ext_toplevel at ./compiler/typeinfer.jl:986
jfptr_typeinf_ext_toplevel_16088.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
jl_type_infer at /buildworker/worker/package_linux64/build/src/gf.c:319
jl_generate_fptr_impl at /buildworker/worker/package_linux64/build/src/jitlayers.cpp:314
jl_compile_method_internal at /buildworker/worker/package_linux64/build/src/gf.c:2072
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2350 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
autodiff at /mnt/julia/packages/Enzyme/di3zM/src/Enzyme.jl:285
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
do_apply at /buildworker/worker/package_linux64/build/src/builtins.c:725
autodiff at /mnt/julia/packages/Enzyme/di3zM/src/Enzyme.jl:319
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
do_apply at /buildworker/worker/package_linux64/build/src/builtins.c:725
autodiff at /mnt/julia/packages/Enzyme/di3zM/src/Enzyme.jl:412
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
do_call at /buildworker/worker/package_linux64/build/src/interpreter.c:126
eval_value at /buildworker/worker/package_linux64/build/src/interpreter.c:215
eval_stmt_value at /buildworker/worker/package_linux64/build/src/interpreter.c:166 [inlined]
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:612
jl_interpret_toplevel_thunk at /buildworker/worker/package_linux64/build/src/interpreter.c:750
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:906
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:850
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:850
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:556
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:522
jl_interpret_toplevel_thunk at /buildworker/worker/package_linux64/build/src/interpreter.c:750
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:906
ijl_toplevel_eval_in at /buildworker/worker/package_linux64/build/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]
eval_user_input at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:356
jfptr_run_repl_63590.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
#964 at ./client.jl:419
jfptr_YY.964_53574.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
jl_f__call_latest at /buildworker/worker/package_linux64/build/src/builtins.c:769
#invokelatest#2 at ./essentials.jl:729 [inlined]
invokelatest at ./essentials.jl:727 [inlined]
run_main_repl at ./client.jl:404
exec_options at ./client.jl:318
_start at ./client.jl:522
jfptr__start_58493.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
true_main at /buildworker/worker/package_linux64/build/src/jlapi.c:567
jl_repl_entrypoint at /buildworker/worker/package_linux64/build/src/jlapi.c:711
main at /buildworker/worker/package_linux64/build/cli/loader_exe.c:59
unknown function (ip: 0x7f4af79d9d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
_start at /mnt/softwares/julia-nightly/bin/julia (unknown line)
Allocations: 133145644 (Pool: 133050263; Big: 95381); GC: 61

avik-pal, Jul 06 '22 04:07

cc @vchuravy in case you want some neural network test cases.

avik-pal, Jul 06 '22 04:07

cc: @wsmoses

vchuravy, Jul 06 '22 10:07

Oh, that’s cool; last I checked, Dense wasn’t working. As for the BatchNorm error, would you mind minimizing it (ideally to pure Julia with no packages) and posting an issue in Enzyme.jl so we can try to get it fixed?

wsmoses, Jul 06 '22 12:07

@avik-pal bumping the above so we can start working on a fix

wsmoses, Jul 24 '22 00:07

I haven't been able to pinpoint the cause of the segfault, but here is something that doesn't work:

using Enzyme

x = randn(10, 2)
dx = zero(x)

@inline @generated _safe_vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x

lfn(x) = sum(_safe_vec(x))

Enzyme.autodiff(lfn, Duplicated(x, dx))
ERROR: Enzyme: Not yet implemented reverse for jl_array_reshape

Same for

using Enzyme

@inline @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int},
                                             ly::Int)::typeof(sx) where {N}
    if ly == sx[N - 1]
        return ntuple(i -> i == N - 1 ? ly : 1, N)
    elseif N > 2 && ly == sx[N - 1] * sx[N - 2]
        return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N)
    else
        error("Invalid Dimensions")
    end
end

@inline _reshape_into_proper_shape(x::Nothing, y)::Nothing = x
@inline _reshape_into_proper_shape(x, y)::typeof(y) = reshape(x,
                                                              _get_reshape_dims(size(y),
                                                                                length(x)))

x = randn(10)
dx = zero(x)
y = randn(10, 1)
lfn(x) = sum(_reshape_into_proper_shape(x, y))

Enzyme.autodiff(lfn, Duplicated(x, dx))

avik-pal, Jul 26 '22 05:07

@wsmoses I am guessing this is known, or should I open an issue?

avik-pal, Jul 26 '22 05:07

Can you try main since https://github.com/EnzymeAD/Enzyme.jl/pull/380 just landed?

vchuravy, Jul 26 '22 12:07

With Enzyme main these errors are fixed. The segfault still persists.

avik-pal, Aug 01 '22 03:08

👍 Let me know when you've managed to reduce the BatchNorm bug to pure Julia and I can have a look.

wsmoses, Aug 01 '22 04:08

BatchNorm has stopped segfaulting on Julia 1.8 + Enzyme https://github.com/EnzymeAD/Enzyme.jl/pull/446, but it now errors:

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-linux-gnu'

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-linux-gnu'

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-linux-gnu'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/07qaN/src/utils.jl:35
ERROR: Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate

avik-pal, Sep 08 '22 16:09

@avik-pal By the way, we added some backtrace info for these types of errors, in case you want to see where your type instability is coming from (and perhaps fix it).
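
A minimal sketch of checking a function for type instability with plain Julia tools, independent of Enzyme's backtrace. The functions `unstable` and `stable` and the helper `is_type_stable` are hypothetical names made up for illustration; they are not part of Lux or Enzyme.

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

# Hypothetical example: the return type depends on a runtime value,
# so inference can only conclude Union{Float64, Int64}.
unstable(x) = x > 0 ? 1 : 1.0

# Hypothetical example: both branches return Float64.
stable(x) = x > 0 ? 1.0 : -1.0

# Hypothetical helper (an assumption, not a library API): true when every
# matching method infers to a concrete return type for the given argument types.
is_type_stable(f, argtypes) = all(isconcretetype, Base.return_types(f, argtypes))

# Prints the typed code; non-concrete types are highlighted in the output.
@code_warntype unstable(2)

println(is_type_stable(unstable, (Int,)))
println(is_type_stable(stable, (Int,)))
```

Running a check like this on the pieces of a loss function can narrow down which call returns a non-concrete type before handing it to `Enzyme.autodiff`.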

wsmoses, Sep 26 '22 19:09

rep:   %14 = bitcast {} addrspace(10)* %13 to [3 x {} addrspace(10)*] addrspace(10)*, !enzyme_caststack !6
prev:   %6 = alloca [3 x {} addrspace(10)*], align 8
inst:   call fastcc void @julia__normalization_3909([3 x {} addrspace(10)*]* noalias nocapture noundef nonnull sret([3 x {} addrspace(10)*]) align 8 dereferenceable(24) %6, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %2, {} addrspace(10)* nonnull align 16 dereferenceable(40) %21, {} addrspace(10)* nonnull align 16 dereferenceable(40) %23, {} addrspace(10)* nonnull align 16 dereferenceable(40) %17, {} addrspace(10)* nonnull align 16 dereferenceable(40) %19, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %32, float %33, float %34) #64, !dbg !141
Illegal address space propagation
UNREACHABLE executed at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:366!

signal (6): Aborted
in expression starting at REPL[8]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
_ZN4llvm25llvm_unreachable_internalEPKcS1_j at /mnt/julia/juliaup/julia-1.8.1+0.x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
RecursivelyReplaceAddressSpace at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:366
UpgradeAllocasToMallocs at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:437
preprocessForClone at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:1694
CloneFunctionWithReturns at /workspace/srcdir/Enzyme/enzyme/Enzyme/FunctionUtils.cpp:1982
CreateFromClone at /workspace/srcdir/Enzyme/enzyme/Enzyme/GradientUtils.cpp:3399
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2058
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:11768
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:3916
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:473
EnzymeCreatePrimalAndGradient at /mnt/julia/packages/Enzyme/wJg1H/src/api.jl:118
enzyme! at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:4289
unknown function (ip: 0x7ff998fdfc84)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
#codegen#112 at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5275
codegen##kw at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:4945 [inlined]
_thunk at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5748 [inlined]
_thunk at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5742
unknown function (ip: 0x7ff998f751dd)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
cached_compilation at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5786
unknown function (ip: 0x7ff998f50f03)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
#s814#139 at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5846 [inlined]
#s814#139 at ./none:0
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
GeneratedFunctionStub at ./boot.jl:582
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
jl_call_staged at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/method.c:520
ijl_code_for_staged at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:284
typeinf_edge at ./compiler/typeinfer.jl:868
abstract_call_method at ./compiler/abstractinterpretation.jl:641
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1696
abstract_call at ./compiler/abstractinterpretation.jl:1766
abstract_call at ./compiler/abstractinterpretation.jl:1733
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1890
typeinf_local at ./compiler/abstractinterpretation.jl:2366
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2462
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:641
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1696
abstract_call at ./compiler/abstractinterpretation.jl:1766
abstract_call at ./compiler/abstractinterpretation.jl:1733
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1890
typeinf_local at ./compiler/abstractinterpretation.jl:2366
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2462
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:641
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1696
abstract_call at ./compiler/abstractinterpretation.jl:1766
abstract_call at ./compiler/abstractinterpretation.jl:1733
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1890
typeinf_local at ./compiler/abstractinterpretation.jl:2366
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2462
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:967
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1000
typeinf_ext_toplevel at ./compiler/typeinfer.jl:996
jfptr_typeinf_ext_toplevel_18807.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
jl_type_infer at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:319
jl_generate_fptr_impl at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/jitlayers.cpp:319
jl_compile_method_internal at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2081 [inlined]
jl_compile_method_internal at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2025
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2359 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
autodiff at /mnt/julia/packages/Enzyme/wJg1H/src/Enzyme.jl:296
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
do_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/builtins.c:730
autodiff at /mnt/julia/packages/Enzyme/wJg1H/src/Enzyme.jl:330
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
do_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/builtins.c:730
autodiff at /mnt/julia/packages/Enzyme/wJg1H/src/Enzyme.jl:423
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
do_call at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:126
eval_value at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:215
eval_stmt_value at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:166 [inlined]
eval_body at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:612
jl_interpret_toplevel_thunk at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:750
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:906
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:850
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:850
eval_body at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:556
eval_body at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:522
jl_interpret_toplevel_thunk at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:750
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:906
ijl_toplevel_eval_in at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]
eval_user_input at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
jfptr_run_repl_66557.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
#967 at ./client.jl:419
jfptr_YY.967_49700.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
jl_f__call_latest at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/builtins.c:774
#invokelatest#2 at ./essentials.jl:729 [inlined]
invokelatest at ./essentials.jl:726 [inlined]
run_main_repl at ./client.jl:404
exec_options at ./client.jl:318
_start at ./client.jl:522
jfptr__start_61720.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
true_main at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/jlapi.c:575
jl_repl_entrypoint at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/jlapi.c:719
main at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/cli/loader_exe.c:59
unknown function (ip: 0x7ffa36fd0d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 128727060 (Pool: 128641444; Big: 85616); GC: 56

avik-pal, Sep 28 '22 17:09

using Enzyme, Statistics  # Statistics provides `mean`, which is called below

function _update_normalization_statistics(x::AbstractArray{<:Real, N},
                                          running_mean::AbstractArray{<:Real, N},
                                          running_var::AbstractArray{<:Real, N},
                                          batchmean::AbstractArray{<:Real, N},
                                          batchvar::AbstractArray{<:Real, N},
                                          momentum::Real, reduce_dims) where {N}
    sx = size(x)
    m = (eltype(x))(prod(sx[reduce_dims]))::eltype(x)
    if last(reduce_dims) != N
        batchmean = mean(batchmean; dims=N)
        batchvar = mean(batchvar; dims=N)
    end
    running_mean = @. (1 - momentum) * running_mean + momentum * batchmean
    running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m)))
    return (running_mean, running_var)
end

function loss_function(x, xmean, xvar, scale, bias, rm, rv)
    return sum(sum, _update_normalization_statistics(x, rm, rv, xmean, xvar, 0.1, [2]))
end

function _setup()
  x = rand(10, 2)
  xmean = rand(10, 1)
  xvar = abs2.(rand(10, 1))
  scale = rand(10, 1)
  bias = rand(10, 1)
  rm = rand(10, 1)
  rv = abs2.(rand(10, 1))

  ps = (x, xmean, xvar, scale, bias, rm, rv)
  return ps, zero.(ps)
end

(x, xmean, xvar, scale, bias, rm, rv), (dx, dxmean, dxvar, dscale, dbias, drm, drv) = _setup()

loss_function(x, xmean, xvar, scale, bias, rm, rv)

Enzyme.autodiff(loss_function, Duplicated(x, dx), Duplicated(xmean, dxmean), Duplicated(xvar, dxvar), Duplicated(scale, dscale), Duplicated(bias, dbias), Duplicated(rm, drm), Duplicated(rv, drv))

This fails with a different assertion:

julia: /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstrTypes.h:1445: void llvm::CallBase::setCalledFunction(llvm::FunctionType*, llvm::Value*): Assertion `getType() == FTy->getReturnType()' failed.

signal (6): Aborted
in expression starting at REPL[19]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x7fd5df1a271a)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
setCalledFunction at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstrTypes.h:1445 [inlined]
setCalledFunction at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstrTypes.h:1430 [inlined]
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2705
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:11768
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2213
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:11768
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2213
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:11768
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2213
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:11768
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreateAugmentedPrimal at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:2213
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:11768
delegateCallInst at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:302 [inlined]
visitCall at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209 [inlined]
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/Instruction.def:209
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:112 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:3916
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:473
EnzymeCreatePrimalAndGradient at /mnt/julia/packages/Enzyme/wJg1H/src/api.jl:118
enzyme! at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:4289
unknown function (ip: 0x7fd549857cd4)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
#codegen#112 at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5275
codegen##kw at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:4945 [inlined]
_thunk at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5748 [inlined]
_thunk at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5742
unknown function (ip: 0x7fd5497f378d)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
cached_compilation at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5786
unknown function (ip: 0x7fd5497d00d3)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
#s814#139 at /mnt/julia/packages/Enzyme/wJg1H/src/compiler.jl:5846 [inlined]
#s814#139 at ./none:0
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
GeneratedFunctionStub at ./boot.jl:582
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
jl_call_staged at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/method.c:520
ijl_code_for_staged at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:284
typeinf_edge at ./compiler/typeinfer.jl:868
abstract_call_method at ./compiler/abstractinterpretation.jl:641
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1696
abstract_call at ./compiler/abstractinterpretation.jl:1766
abstract_call at ./compiler/abstractinterpretation.jl:1733
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1890
typeinf_local at ./compiler/abstractinterpretation.jl:2366
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2462
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:641
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1696
abstract_call at ./compiler/abstractinterpretation.jl:1766
abstract_call at ./compiler/abstractinterpretation.jl:1733
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1890
typeinf_local at ./compiler/abstractinterpretation.jl:2366
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2462
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:877
abstract_call_method at ./compiler/abstractinterpretation.jl:641
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:153
abstract_call_known at ./compiler/abstractinterpretation.jl:1696
abstract_call at ./compiler/abstractinterpretation.jl:1766
abstract_call at ./compiler/abstractinterpretation.jl:1733
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1890
typeinf_local at ./compiler/abstractinterpretation.jl:2366
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2462
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:967
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1000
typeinf_ext_toplevel at ./compiler/typeinfer.jl:996
jfptr_typeinf_ext_toplevel_18807.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
jl_type_infer at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:319
jl_generate_fptr_impl at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/jitlayers.cpp:319
jl_compile_method_internal at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2081 [inlined]
jl_compile_method_internal at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2025
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2359 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
autodiff at /mnt/julia/packages/Enzyme/wJg1H/src/Enzyme.jl:296
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
do_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/builtins.c:730
autodiff at /mnt/julia/packages/Enzyme/wJg1H/src/Enzyme.jl:330
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
do_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/builtins.c:730
autodiff at /mnt/julia/packages/Enzyme/wJg1H/src/Enzyme.jl:423
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
do_call at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:126
eval_value at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:215
eval_stmt_value at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:166 [inlined]
eval_body at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:612
jl_interpret_toplevel_thunk at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/interpreter.c:750
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:906
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:850
jl_toplevel_eval_flex at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:850
ijl_toplevel_eval_in at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]
eval_user_input at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
jfptr_run_repl_66557.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
#967 at ./client.jl:419
jfptr_YY.967_49700.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
jl_f__call_latest at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/builtins.c:774
#invokelatest#2 at ./essentials.jl:729 [inlined]
invokelatest at ./essentials.jl:726 [inlined]
run_main_repl at ./client.jl:404
exec_options at ./client.jl:318
_start at ./client.jl:522
jfptr__start_61720.clone_1 at /mnt/julia/juliaup/julia-1.8.1+0.x64/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2367 [inlined]
ijl_apply_generic at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/gf.c:2549
jl_apply at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/julia.h:1838 [inlined]
true_main at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/jlapi.c:575
jl_repl_entrypoint at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/src/jlapi.c:719
main at /cache/build/default-amdci5-0/julialang/julia-release-1-dot-8/cli/loader_exe.c:59
unknown function (ip: 0x7fd5df1a3d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 84566216 (Pool: 84501708; Big: 64508); GC: 45

avik-pal, Sep 28 '22 17:09

On 1.8.1 I get:

julia> Enzyme.autodiff(loss_function, Duplicated(x, dx), Duplicated(xmean, dxmean), Duplicated(xvar, dxvar), Duplicated(scale, dscale), Duplicated(bias, dbias), Duplicated(rm, drm), Duplicated(rv, drv))
ERROR: Duplicated Returns not yet handled
Stacktrace:
 [1] autodiff(::Enzyme.ReverseMode, ::typeof(loss_function), ::Type{Duplicated{Any}}, ::Duplicated{Matrix{Float64}}, ::Vararg{Duplicated{Matrix{Float64}}})
   @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:294
 [2] autodiff(::Enzyme.ReverseMode, ::typeof(loss_function), ::Duplicated{Matrix{Float64}}, ::Duplicated{Matrix{Float64}}, ::Vararg{Duplicated{Matrix{Float64}}})
   @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:330
 [3] autodiff(::typeof(loss_function), ::Duplicated{Matrix{Float64}}, ::Duplicated{Matrix{Float64}}, ::Duplicated{Matrix{Float64}}, ::Vararg{Duplicated{Matrix{Float64}}})
   @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:423
 [4] top-level scope
   @ REPL[13]:1

This implies a type-unstable return, or a return of type Any, a vector, etc., rather than a float.

Also https://github.com/EnzymeAD/Enzyme.jl/pull/462 alongside a jll bump should fix the latter error, but we still need a minimal reproducer for the first one.
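To illustrate the type-stability point above, here is a minimal sketch. The functions unstable and stable are hypothetical and not taken from Lux or Enzyme; the sketch only shows how Base.return_types can reveal whether a function is inferred to return a single concrete type such as Float64.

```julia
# Hypothetical functions for illustration; not part of Lux or Enzyme.
unstable(x) = x > 0 ? x : 0        # for a Float64 input: returns Float64 or Int
stable(x) = x > 0 ? x : zero(x)    # always returns the input's own type

# Base.return_types lists the inferred return type for each matching method.
unstable_type = only(Base.return_types(unstable, (Float64,)))
stable_type = only(Base.return_types(stable, (Float64,)))

println(isconcretetype(unstable_type))  # false: the inferred type is a Union
println(isconcretetype(stable_type))    # true: the inferred type is Float64
```

Running the same check on a loss function before passing it to Enzyme.autodiff may help tell a non-float return apart from other failures.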

wsmoses, Sep 28 '22 18:09

using Lux, Random, Enzyme
rng = Random.default_rng()

function loss_function(model, x, ps, st)
    return sum(Lux.apply(model, x, ps, st)[1])
end

model = Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 1)

dps = Lux.fmap(zero, ps)
Enzyme.autodiff(loss_function, Const(model), Const(x), Duplicated(ps, dps), Const(st))
println(dps)

now runs without error, but produces NaN gradients for everything before BatchNorm:

(layer_1 = (weight = Float32[NaN NaN; NaN NaN; NaN NaN; NaN NaN], bias = Float32[NaN; NaN; NaN; NaN;;]), layer_2 = (scale = Float32[0.0, 0.0, 0.0, 0.0], bias = Float32[0.7332505, -0.56484723, -0.85575366, 0.11406183]), layer_3 = (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Float32[1.0; 1.0;;]))

For comparison, Zygote gives:

(layer_1 = (weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (scale = Float32[0.0, -0.0, -0.0, 0.0], bias = Float32[0.7332505, -0.56484723, -0.85575366, 0.11406183]), layer_3 = (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2, 1)))

avik-pal, Mar 20 '23 15:03

@avik-pal could you make a minimal reproducer of the batchnorm error for us to debug/fix?

wsmoses, Apr 07 '23 04:04

@avik-pal the BatchNorm error is caused by a multiplication by infinity in LuxLib.

In particular: https://github.com/LuxDL/LuxLib.jl/blob/da558718590fc904dbe2620fdd98df4619d3669d/src/impl/normalization.jl#L16

    m = eltype(x)(prod(Base.Fix1(size, x), reduce_dims))
    if last(reduce_dims) != N
        batchmean = mean(batchmean; dims=N)
        batchvar = mean(batchvar; dims=N)
    end
    running_mean = @. (1 - momentum) * running_mean + momentum * batchmean
    running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m)))

x = Float32[0.79961216; 0.87499446; 0.15016815; -0.53201634;;];

julia> m = eltype(x)(prod(Base.Fix1(size, x), 2))
1.0f0

julia> (m / (m - one(m)))
Inf32

wsmoses, May 14 '23 18:05

Oh, also, in this case batchvar is all zeros:

Float32[0.0; 0.0; 0.0; 0.0;;]

Thus the original code computes NaN (0 * Inf).
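The arithmetic can be reproduced in isolation. This is a minimal sketch that only mirrors the running_var update quoted above; the momentum and initial running_var values are assumptions chosen for illustration, and nothing here calls LuxLib.

```julia
# Standalone arithmetic only; this is not the LuxLib implementation.
x = reshape(Float32[0.79961216, 0.87499446, 0.15016815, -0.53201634], 4, 1)

m = eltype(x)(size(x, 2))          # samples reduced over: batch size 1
correction = m / (m - one(m))      # 1 / 0 in Float32

batchvar = zeros(Float32, 4, 1)    # batch variance is all zeros, as noted above
momentum = 0.1f0                   # assumed value for illustration
running_var = ones(Float32, 4, 1)  # assumed initial statistics

# Same update expression as in the quoted snippet.
new_running_var = @. (1 - momentum) * running_var + momentum * batchvar * correction

println(correction)                   # Inf
println(all(isnan, new_running_var))  # true
```

The zero variance multiplied by the infinite correction yields NaN, and adding the finite running_var term does not recover from it.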

wsmoses, May 14 '23 18:05

You are right; tracking statistics/training shouldn't work with batchsize = 1 for BatchNorm in general.

avik-pal, May 15 '23 14:05

Hi folks! Awesome work! I see that trying to instantiate EnzymeVJP still raises an ArgumentError stating that support for the Enzyme backend has not been implemented yet. What are the plans on this end? :)

classner, May 23 '23 08:05

I hope to get to that eventually (PRs for that are always welcome, though). Currently, I am spending time (albeit kind of too little) testing the backends LuxLib & NNlib with Enzyme first.

avik-pal, May 23 '23 15:05

Closing in favor of https://github.com/LuxDL/Lyme.jl

avik-pal, Sep 24 '23 01:09