Enzyme.jl

Excessive compile times and failures on GPU

Open ExpandingMan opened this issue 10 months ago • 3 comments

Minimal Lux examples on GPU are still impractical due to excessive compile times and occasional failures.

Here is a minimal example:

using LinearAlgebra, Random, Statistics, Optimisers
using CUDA
using Lux, LuxCUDA, MLDataDevices
import ProgressMeter as PM
import Enzyme, Zygote

using Lux.MLDataDevices: AbstractDevice

# Make Enzyme report more debugging information (e.g. the LLVM "Current scope" of the
# failing function) on internal compilation errors. The default is false, and Enzyme's
# own error message asks for this to be enabled when filing bug reports.
Enzyme.Compiler.VERBOSE_ERRORS[] = true

"""
    makedata(rng::AbstractRNG; n::Integer=128)

Generate a noisy 1-D regression dataset for the polynomial `x^2 - 2x`.

Returns a tuple `(X, y)` of `1 × n` `Float32` matrices:
- `X` holds `n` evenly spaced points from `-2` to `2`.
- `y` is `X.^2 .- 2 .* X` plus Gaussian noise (scaled by `0.1`) drawn from `rng`.
"""
function makedata(rng::AbstractRNG; n::Integer=128)
    X = reshape(collect(range(-2.0f0, 2.0f0, n)), (1, n))
    # `evalpoly` takes coefficients in ascending order: 0 - 2x + x^2.
    # The coefficient tuple is wrapped in a 1-tuple so broadcasting treats it as a scalar.
    y = evalpoly.(X, ((0, -2, 1),)) .+ randn(rng, Float32, (1, n)) .* 0.1f0
    return (X, y)
end

"""
    main(dev::AbstractDevice, rng=Random.Xoshiro(999),
         model=Lux.Chain(Lux.Dense(1=>16, gelu), Lux.Dense(16=>1)),
         (X, y)=makedata(rng) |> dev;
         nepochs=300, opt=Adam(0.01f0), backend=AutoEnzyme())

Train `model` on `(X, y)` for `nepochs` calls of `Lux.Training.single_train_step!`.
Each step uses an MSE loss, the AD `backend` and the optimiser `opt`.

Parameters and states are created with `Lux.setup(rng, model)` and moved to `dev`.
Returns `(yhat, y)`, where `yhat` is the trained model's prediction on `X`.
"""
function main(dev::AbstractDevice, rng=Random.Xoshiro(999),
              model=Lux.Chain(Lux.Dense(1=>16, gelu), Lux.Dense(16=>1)),
              (X, y)=makedata(rng) |> dev;
              nepochs=300,
              opt=Adam(0.01f0),
              backend=AutoEnzyme(),
             )
    (θ, ψ) = Lux.setup(rng, model) |> dev

    s = Lux.Training.TrainState(model, θ, ψ, opt)
    pm = PM.Progress(nepochs)
    for _ ∈ 1:nepochs
        # Only the updated train state is kept; gradients, loss and stats are discarded.
        (_, _, _, s) = Lux.Training.single_train_step!(
            backend, MSELoss(),
            (X, y), s,
        )
        PM.next!(pm)
    end
    PM.finish!(pm)

    # Predict with the parameters/states held by the train state `s` returned by the
    # last training step. Do not use the initial `θ`/`ψ` that were passed into
    # `TrainState`.
    # States are switched to inference mode so that layers with train/test behavior
    # (e.g. Dropout, BatchNorm) are evaluated correctly.
    (yhat, _) = Lux.apply(model, X, s.parameters, Lux.testmode(s.states))

    return (yhat, y)
end

On the CPU I get:

◖◗ @time main(cpu_device());
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
 43.469745 seconds (481.86 M allocations: 20.523 GiB, 5.22% gc time, 99.98% compilation time)

This is worryingly slow for such a small example, but depending on how it scales, it isn't necessarily prohibitive.

However, on the GPU it compiles for about 10 to 15 minutes and then fails with the following error:

◖◗ @time main(gpu_device());
ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope:
define internal fastcc void @julia_nonblocking_synchronize_187572({} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0) unnamed_addr #197 !dbg !8001 {
top:
  %1 = alloca [3 x [2 x {} addrspace(10)*]], align 8
  %pgcstack = call {}*** @julia.get_pgcstack()
  %ptls_field6 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %2 = bitcast {}*** %ptls_field6 to i64***
  %ptls_load78 = load i64**, i64*** %2, align 8, !tbaa !381
  %3 = getelementptr inbounds i64*, i64** %ptls_load78, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !385
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint), !dbg !8002
  fence syncscope("singlethread") seq_cst
  %4 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* addrspacecast ({}* inttoptr (i64 124712895855264 to {}*) to {} addrspace(11)*)) #413, !dbg !8003
  %ptr.i = bitcast {}* %4 to i32*, !dbg !8007
  %rv.i = atomicrmw add i32* %ptr.i, i32 1 acq_rel, align 4, !dbg !8007
  %5 = and i32 %rv.i, 3, !dbg !8010
  %.not = icmp eq i32 %5, 0, !dbg !8018
  %narrow = select i1 %.not, i32 4, i32 %5, !dbg !8020
  %6 = zext i32 %narrow to i64, !dbg !8020
  %7 = load i64, i64* inttoptr (i64 124712895855408 to i64*), align 16, !dbg !8022, !tbaa !633, !alias.scope !636, !noalias !637
  %8 = add nsw i64 %6, -1, !dbg !8035
  %.not9 = icmp ult i64 %8, %7, !dbg !8038
  br i1 %.not9, label %L40, label %L49, !dbg !8032

L40:                                              ; preds = %top
  %9 = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637
  %10 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637, !dereferenceable_or_null !494, !align !503
  %11 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %10, {} addrspace(10)** %9), !dbg !8043
  %12 = bitcast {} addrspace(10)* addrspace(13)* %11 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8043
  %13 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %12, i64 %8, i64 0, i64 0, !dbg !8043
  %14 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %13, align 8, !dbg !8043, !tbaa !4485, !alias.scope !429, !noalias !430
  %.not24 = icmp eq {} addrspace(10)* %14, null, !dbg !8043
  br i1 %.not24, label %L49, label %pass3, !dbg !8034

L49:                                              ; preds = %L40, %top
  call fastcc void @julia_create_synchronization_worker_189388(i64 signext %6), !dbg !8045
  %.pre = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
  %.pre25 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
  %.pre26 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %.pre25, {} addrspace(10)** %.pre), !dbg !8046
  %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8046
  %.unpack.elt.phi.trans.insert = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre27, i64 %8, i64 0, i64 0
  %.unpack.unpack.pre = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt.phi.trans.insert, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.not23 = icmp eq {} addrspace(10)* %.unpack.unpack.pre, null, !dbg !8046
  br i1 %.not23, label %fail2, label %pass3, !dbg !8046

L75:                                              ; preds = %pass3
  call fastcc void @julia_throw_api_error_187593(i32 zeroext %18) #414, !dbg !8049
  unreachable, !dbg !8049

L77:                                              ; preds = %pass3
  ret void, !dbg !8050

fail2:                                            ; preds = %L49
  %15 = load {}*, {}** @jl_undefref_exception, align 8, !dbg !8046, !tbaa !385, !alias.scope !437, !noalias !438, !nonnull !380
  %16 = addrspacecast {}* %15 to {} addrspace(12)*, !dbg !8046
  call void @ijl_throw({} addrspace(12)* %16), !dbg !8046
  unreachable, !dbg !8046

pass3:                                            ; preds = %L40, %L49
  %nodecayed..pre-phi2834 = phi {} addrspace(10)*
  %nodecayedoff..pre-phi2834 = phi i64
  %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
  %.unpack.unpack33 = phi {} addrspace(10)* [ %.unpack.unpack.pre, %L49 ], [ %14, %L40 ]
  %.unpack.elt14 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 0, i64 1, !dbg !8046
  %.unpack.unpack15 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt14, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack11.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 0, !dbg !8046
  %.unpack11.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack11.elt17 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 1, !dbg !8046
  %.unpack11.unpack18 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt17, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack13.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 0, !dbg !8046
  %.unpack13.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack13.elt20 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 1, !dbg !8046
  %.unpack13.unpack21 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt20, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.fca.0.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 0, !dbg !8051
  store {} addrspace(10)* %.unpack.unpack33, {} addrspace(10)** %.fca.0.0.gep, align 8, !dbg !8051, !noalias !566
  %.fca.0.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 1, !dbg !8051
  store {} addrspace(10)* %.unpack.unpack15, {} addrspace(10)** %.fca.0.1.gep, align 8, !dbg !8051, !noalias !566
  %.fca.1.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 0, !dbg !8051
  store {} addrspace(10)* %.unpack11.unpack, {} addrspace(10)** %.fca.1.0.gep, align 8, !dbg !8051, !noalias !566
  %.fca.1.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 1, !dbg !8051
  store {} addrspace(10)* %.unpack11.unpack18, {} addrspace(10)** %.fca.1.1.gep, align 8, !dbg !8051, !noalias !566
  %.fca.2.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 0, !dbg !8051
  store {} addrspace(10)* %.unpack13.unpack, {} addrspace(10)** %.fca.2.0.gep, align 8, !dbg !8051, !noalias !566
  %.fca.2.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 1, !dbg !8051
  store {} addrspace(10)* %.unpack13.unpack21, {} addrspace(10)** %.fca.2.1.gep, align 8, !dbg !8051, !noalias !566
  %17 = addrspacecast [3 x [2 x {} addrspace(10)*]]* %1 to [3 x [2 x {} addrspace(10)*]] addrspace(11)*, !dbg !8051
  %18 = call fastcc i32 @julia_put__189354([3 x [2 x {} addrspace(10)*]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(48) %17, {} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0), !dbg !8051
  %19 = icmp eq i32 %18, 0, !dbg !8052
  br i1 %19, label %L77, label %L75, !dbg !8057
}

Could not analyze garbage collection behavior of
 inst:   %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
 v0:   %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !461
 v: {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***)
 offset: i64 0
 hasload: true


Stacktrace:
 [1] #synchronize#1003
   @ ~/.julia/packages/CUDA/1kIOw/lib/cudadrv/synchronization.jl:200
 [2] multiple call sites
   @ unknown:0

Stacktrace:
  [1] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:969
  [2] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:653
  [3] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:682
  [4] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:818
  [5] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:976
  [6] optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler/optimize.jl:582
  [7] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:401
  [8] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:960
  [9] enzyme_custom_augfwd
    @ ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:1503 [inlined]
 [10] enzyme_custom_augfwd_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, normalR::Ptr{…}, shadowR::Ptr{…}, tapeR::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/llvmrules.jl:18
 [11] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/R6sE8/src/api.jl:268
 [12] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:1706
 [13] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4550
 [14] codegen
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:3353 [inlined]
 [15] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410
 [16] _thunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410 [inlined]
 [17] cached_compilation
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5462 [inlined]
 [18] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5573
 [19] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5758
 [20] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:485 [inlined]
 [21] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxEnzymeExt ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:8
 [22] compute_gradients
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200 [inlined]
 [23] single_train_step_impl!
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:320 [inlined]
 [24] #single_train_step!#6
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:288 [inlined]
 [25] single_train_step!(backend::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:284
 [26] main(dev::CUDADevice{…}, rng::Xoshiro, model::Chain{…}, ::Tuple{…}; nepochs::Int64, opt::Adam, backend::AutoEnzyme{…})
    @ Main ~/src/lux_enzyme_test.jl:29
 [27] main(dev::CUDADevice{Nothing}, rng::Xoshiro, model::Chain{@NamedTuple{…}, Nothing}, ::Tuple{CuArray{…}, CuArray{…}})
    @ Main ~/src/lux_enzyme_test.jl:17
 [28] macro expansion
    @ ./timing.jl:581 [inlined]
 [29] top-level scope
    @ ./REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.

ExpandingMan avatar Jan 28 '25 00:01 ExpandingMan

On the CPU side, all remaining time is spent in the Julia base compiler:

╎   +11 28880 …ase/compiler/abstractinterpretation.jl:2282; abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo

wsmoses avatar Sep 25 '25 16:09 wsmoses

On the CPU, most of the time is spent compiling code that is missing precompilation statements.

`--trace-compile` output: https://gist.github.com/vchuravy/d2e1a70dbdb4bf76b454867c3b6fd0c2

vchuravy avatar Sep 25 '25 19:09 vchuravy

There's still more to do, but @ExpandingMan, how does this look now? (We've had a lot of nontrivial compile-time improvements recently.)

wsmoses avatar Oct 27 '25 06:10 wsmoses