excessive compile times and failures with GPU
Minimal Lux examples on GPU are still impractical due to excessive compile times and occasional failures.
Here is a minimal example:
using LinearAlgebra, Random, Statistics, Optimisers
using CUDA
using Lux, LuxCUDA, MLDataDevices
import ProgressMeter as PM
import Enzyme, Zygote
using Lux.MLDataDevices: AbstractDevice
# Ask Enzyme for extra debugging information on internal compilation errors.
# Enzyme's own error message requests this setting for bug reports (default: false).
Enzyme.Compiler.VERBOSE_ERRORS[] = true
"""
    makedata(rng::AbstractRNG)

Build a small synthetic 1-D regression data set.

Returns `(X, y)`, both `1×128` `Float32` matrices: `X` holds 128 evenly spaced
points on `[-2, 2]`, and `y` is the polynomial `x^2 - 2x` evaluated at those
points plus Gaussian noise drawn from `rng` and scaled by `0.1`.
"""
function makedata(rng::AbstractRNG)
    npoints = 128
    # Evenly spaced inputs laid out as a single row (features × samples).
    grid = collect(range(-2.0f0, 2.0f0; length=npoints))
    X = reshape(grid, 1, npoints)
    # Coefficients are in ascending order: 0 - 2x + x^2.
    clean = evalpoly.(X, ((0, -2, 1),))
    noise = randn(rng, Float32, (1, npoints)) .* 0.1f0
    y = clean .+ noise
    return (X, y)
end
"""
    main(dev, rng, model, (X, y); nepochs=300, opt=Adam(0.01f0), backend=AutoEnzyme())

Train `model` on the data `(X, y)` with `nepochs` calls to
`Lux.Training.single_train_step!` using `MSELoss()`, then return `(yhat, y)`
where `yhat` is the model output on `X`.

# Arguments
- `dev::AbstractDevice`: device the default data and the parameters/states are moved to.
- `rng`: random number generator used for the default data and for `Lux.setup`.
- `model`: the Lux model to train; defaults to a two-layer dense network.
- `(X, y)`: inputs and targets; defaults to `makedata(rng) |> dev`.

# Keywords
- `nepochs`: number of training steps; every step uses the whole `(X, y)` batch.
- `opt`: optimiser rule handed to `Lux.Training.TrainState`.
- `backend`: AD backend handed to `Lux.Training.single_train_step!`.
"""
function main(dev::AbstractDevice, rng=Random.Xoshiro(999),
              model=Lux.Chain(Lux.Dense(1=>16, gelu), Lux.Dense(16=>1)),
              (X, y)=makedata(rng) |> dev;
              nepochs=300,
              opt=Adam(0.01f0),
              backend=AutoEnzyme(),
             )
    (θ, ψ) = Lux.setup(rng, model) |> dev
    s = Lux.Training.TrainState(model, θ, ψ, opt)
    pm = PM.Progress(nepochs)
    for _ in 1:nepochs
        # Only the updated train state is needed; the gradients, loss value and
        # stats returned by the step are discarded.
        (_, _, _, s) = Lux.Training.single_train_step!(
            backend, MSELoss(),
            (X, y), s,
        )
        PM.next!(pm)
    end
    PM.finish!(pm)
    # Evaluate with the parameters/states held by the returned train state
    # instead of the initial `θ`/`ψ`, so the result does not depend on the
    # training step having mutated the original arrays in place.
    (yhat, _) = Lux.apply(model, X, s.parameters, s.states)
    return (yhat, y)
end
On the CPU I get
◖◗ @time main(cpu_device());
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
43.469745 seconds (481.86 M allocations: 20.523 GiB, 5.22% gc time, 99.98% compilation time)
This is worryingly slow for such a small example, but depending on how it scales it isn't necessarily prohibitive.
However, on the GPU it tries to compile for about 10 to 15 minutes and then gives the following error:
◖◗ @time main(gpu_device());
ERROR: Enzyme compilation failed due to an internal error.
Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope:
define internal fastcc void @julia_nonblocking_synchronize_187572({} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0) unnamed_addr #197 !dbg !8001 {
top:
%1 = alloca [3 x [2 x {} addrspace(10)*]], align 8
%pgcstack = call {}*** @julia.get_pgcstack()
%ptls_field6 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
%2 = bitcast {}*** %ptls_field6 to i64***
%ptls_load78 = load i64**, i64*** %2, align 8, !tbaa !381
%3 = getelementptr inbounds i64*, i64** %ptls_load78, i64 2
%safepoint = load i64*, i64** %3, align 8, !tbaa !385
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint), !dbg !8002
fence syncscope("singlethread") seq_cst
%4 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* addrspacecast ({}* inttoptr (i64 124712895855264 to {}*) to {} addrspace(11)*)) #413, !dbg !8003
%ptr.i = bitcast {}* %4 to i32*, !dbg !8007
%rv.i = atomicrmw add i32* %ptr.i, i32 1 acq_rel, align 4, !dbg !8007
%5 = and i32 %rv.i, 3, !dbg !8010
%.not = icmp eq i32 %5, 0, !dbg !8018
%narrow = select i1 %.not, i32 4, i32 %5, !dbg !8020
%6 = zext i32 %narrow to i64, !dbg !8020
%7 = load i64, i64* inttoptr (i64 124712895855408 to i64*), align 16, !dbg !8022, !tbaa !633, !alias.scope !636, !noalias !637
%8 = add nsw i64 %6, -1, !dbg !8035
%.not9 = icmp ult i64 %8, %7, !dbg !8038
br i1 %.not9, label %L40, label %L49, !dbg !8032
L40: ; preds = %top
%9 = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637
%10 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637, !dereferenceable_or_null !494, !align !503
%11 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %10, {} addrspace(10)** %9), !dbg !8043
%12 = bitcast {} addrspace(10)* addrspace(13)* %11 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8043
%13 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %12, i64 %8, i64 0, i64 0, !dbg !8043
%14 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %13, align 8, !dbg !8043, !tbaa !4485, !alias.scope !429, !noalias !430
%.not24 = icmp eq {} addrspace(10)* %14, null, !dbg !8043
br i1 %.not24, label %L49, label %pass3, !dbg !8034
L49: ; preds = %L40, %top
call fastcc void @julia_create_synchronization_worker_189388(i64 signext %6), !dbg !8045
%.pre = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
%.pre25 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
%.pre26 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %.pre25, {} addrspace(10)** %.pre), !dbg !8046
%.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8046
%.unpack.elt.phi.trans.insert = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre27, i64 %8, i64 0, i64 0
%.unpack.unpack.pre = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt.phi.trans.insert, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
%.not23 = icmp eq {} addrspace(10)* %.unpack.unpack.pre, null, !dbg !8046
br i1 %.not23, label %fail2, label %pass3, !dbg !8046
L75: ; preds = %pass3
call fastcc void @julia_throw_api_error_187593(i32 zeroext %18) #414, !dbg !8049
unreachable, !dbg !8049
L77: ; preds = %pass3
ret void, !dbg !8050
fail2: ; preds = %L49
%15 = load {}*, {}** @jl_undefref_exception, align 8, !dbg !8046, !tbaa !385, !alias.scope !437, !noalias !438, !nonnull !380
%16 = addrspacecast {}* %15 to {} addrspace(12)*, !dbg !8046
call void @ijl_throw({} addrspace(12)* %16), !dbg !8046
unreachable, !dbg !8046
pass3: ; preds = %L40, %L49
%nodecayed..pre-phi2834 = phi {} addrspace(10)*
%nodecayedoff..pre-phi2834 = phi i64
%.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
%.unpack.unpack33 = phi {} addrspace(10)* [ %.unpack.unpack.pre, %L49 ], [ %14, %L40 ]
%.unpack.elt14 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 0, i64 1, !dbg !8046
%.unpack.unpack15 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt14, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
%.unpack11.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 0, !dbg !8046
%.unpack11.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
%.unpack11.elt17 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 1, !dbg !8046
%.unpack11.unpack18 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt17, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
%.unpack13.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 0, !dbg !8046
%.unpack13.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
%.unpack13.elt20 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 1, !dbg !8046
%.unpack13.unpack21 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt20, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
%.fca.0.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 0, !dbg !8051
store {} addrspace(10)* %.unpack.unpack33, {} addrspace(10)** %.fca.0.0.gep, align 8, !dbg !8051, !noalias !566
%.fca.0.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 1, !dbg !8051
store {} addrspace(10)* %.unpack.unpack15, {} addrspace(10)** %.fca.0.1.gep, align 8, !dbg !8051, !noalias !566
%.fca.1.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 0, !dbg !8051
store {} addrspace(10)* %.unpack11.unpack, {} addrspace(10)** %.fca.1.0.gep, align 8, !dbg !8051, !noalias !566
%.fca.1.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 1, !dbg !8051
store {} addrspace(10)* %.unpack11.unpack18, {} addrspace(10)** %.fca.1.1.gep, align 8, !dbg !8051, !noalias !566
%.fca.2.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 0, !dbg !8051
store {} addrspace(10)* %.unpack13.unpack, {} addrspace(10)** %.fca.2.0.gep, align 8, !dbg !8051, !noalias !566
%.fca.2.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 1, !dbg !8051
store {} addrspace(10)* %.unpack13.unpack21, {} addrspace(10)** %.fca.2.1.gep, align 8, !dbg !8051, !noalias !566
%17 = addrspacecast [3 x [2 x {} addrspace(10)*]]* %1 to [3 x [2 x {} addrspace(10)*]] addrspace(11)*, !dbg !8051
%18 = call fastcc i32 @julia_put__189354([3 x [2 x {} addrspace(10)*]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(48) %17, {} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0), !dbg !8051
%19 = icmp eq i32 %18, 0, !dbg !8052
br i1 %19, label %L77, label %L75, !dbg !8057
}
Could not analyze garbage collection behavior of
inst: %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
v0: %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !461
v: {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***)
offset: i64 0
hasload: true
Stacktrace:
[1] #synchronize#1003
@ ~/.julia/packages/CUDA/1kIOw/lib/cudadrv/synchronization.jl:200
[2] multiple call sites
@ unknown:0
Stacktrace:
[1] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:969
[2] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:653
[3] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:682
[4] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:818
[5] nodecayed_phis!(mod::LLVM.Module)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:976
[6] optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler/optimize.jl:582
[7] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world::UInt64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:401
[8] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:960
[9] enzyme_custom_augfwd
@ ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:1503 [inlined]
[10] enzyme_custom_augfwd_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, normalR::Ptr{…}, shadowR::Ptr{…}, tapeR::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/llvmrules.jl:18
[11] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/R6sE8/src/api.jl:268
[12] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:1706
[13] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4550
[14] codegen
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:3353 [inlined]
[15] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410
[16] _thunk
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410 [inlined]
[17] cached_compilation
@ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5462 [inlined]
[18] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5573
[19] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5758
[20] autodiff
@ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:485 [inlined]
[21] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ LuxEnzymeExt ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:8
[22] compute_gradients
@ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200 [inlined]
[23] single_train_step_impl!
@ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:320 [inlined]
[24] #single_train_step!#6
@ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:288 [inlined]
[25] single_train_step!(backend::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:284
[26] main(dev::CUDADevice{…}, rng::Xoshiro, model::Chain{…}, ::Tuple{…}; nepochs::Int64, opt::Adam, backend::AutoEnzyme{…})
@ Main ~/src/lux_enzyme_test.jl:29
[27] main(dev::CUDADevice{Nothing}, rng::Xoshiro, model::Chain{@NamedTuple{…}, Nothing}, ::Tuple{CuArray{…}, CuArray{…}})
@ Main ~/src/lux_enzyme_test.jl:17
[28] macro expansion
@ ./timing.jl:581 [inlined]
[29] top-level scope
@ ./REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.
On the CPU side, all of the remaining time is spent in the Julia base compiler:
╎ +11 28880 …ase/compiler/abstractinterpretation.jl:2282; abstract_call(interp::Core.Compiler.NativeInterpreter, arginfo
On the CPU, the majority of the time spent comes from missing precompilation statements.
--trace-compile result https://gist.github.com/vchuravy/d2e1a70dbdb4bf76b454867c3b6fd0c2
There's still more to go, but @ExpandingMan, how does this look now? (We've had a lot of nontrivial compile-time improvements recently.)