Flux.jl
Create a flag to use Enzyme as the AD in training/etc.
Motivation and description
Now that all the internal Flux tests pass, we should start setting up for integration. Having such a flag would make it easier for me and others to test things out, debug, etc.
Possible Implementation
No response
cc @CarloLucibello @ToucheSir
I think the basic interface needed is a nice gradient function.
This code is still not working, though, on either CPU or CUDA GPU:
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

# Build a zeroed copy of the model to serve as the gradient (shadow) buffer.
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

# Zygote-like wrapper: returns one gradient per argument.
function gradient_ez(f, x...)
    args = []
    for xi in x
        if xi isa Number
            push!(args, Active(xi))
        else
            push!(args, Duplicated(xi, make_zero(xi)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    # Scalar gradients are returned by autodiff; the others accumulate in the shadow.
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu # CPU training
# device = Flux.gpu # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
              Dense(32 => num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1] # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
We should add tests for the loss functions. This one is failing:
gradient_ez(ŷ -> Flux.logitcrossentropy(ŷ, y), randn(Float32, num_classes, batch_size))
Here is a modification of your code above that should be more performant/stable/etc. (closures are bad).
In any case, it still has the same issue, which I will investigate.
# using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

# Zero every array in the gradient buffer in place; other leaves are left as-is.
_make_zero!(x::AbstractArray) = x .= 0
_make_zero!(x) = x
make_zero!(model) = fmap(_make_zero!, model)

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu # CPU training
# device = Flux.gpu # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
              Dense(32 => num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
g = deepcopy(model) # gradient buffer, reused across epochs
for epoch in 1:epochs
    make_zero!(g)
    Enzyme.autodiff(Reverse, loss, Duplicated(model, g), Const(X), Const(y))
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
Yeah this works now with the NNlib type stability fix https://github.com/FluxML/NNlib.jl/pull/584
The previous "interface" was to import the corresponding AD package and just call e.g. Tracker.withgradient.
The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it.
To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
But I suggest starting with a dedicated doc page on using Enzyme + Flux, which will be easier to get through quickly.
Sure, I think docs would be a great start. I don't really know how to use Flux or where that would go best, so I'll leave that to you.
At the same time, if we're already doing API design, for training it would be nice to not have to constantly reallocate the gradient buffer (with make_zero). I don't know if there's an in-place zeroing function you have for models, but that would be highly beneficial here.
it would be nice to not have to constantly reallocate the gradient buffer
I edited the code in your post to zero the gradient in-place. A slight problem in make_zero! is that it sets the arrays to zero but not the scalar fields, so those are going to be accumulated. That can be fixed later, and in principle it is not even a problem since scalars are not updated by the optimizer.
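To make the scalar-field caveat concrete, here is a minimal sketch in plain Julia. The `zero_arrays!` helper is hypothetical (it is not part of Flux, Functors or Enzyme) and only walks tuples and named tuples, whereas the `make_zero!` in the post uses `fmap` to walk the model. It shows that arrays can be reset in place while an immutable scalar field keeps its old value:

```julia
# Hypothetical helper for illustration only: zero every numeric array inside a
# nested Tuple/NamedTuple in place, and leave all other leaves untouched.
zero_arrays!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
zero_arrays!(x::Union{Tuple,NamedTuple}) = (foreach(zero_arrays!, x); x)
zero_arrays!(x) = x  # scalars are immutable, so there is nothing to reset in place

# A stand-in for a gradient buffer with two "layers" and one scalar field.
g = (layers = ((weight = Float32[1 2; 3 4], bias = Float32[5, 6]),
               (weight = Float32[7 8],)),
     scale = 2.0f0)

zero_arrays!(g)

@show g.layers[1].weight  # all zeros: the array was mutated in place
@show g.scale             # still 2.0f0: the scalar field was not reset
```

The buffer is reused rather than reallocated, which is the point of zeroing in place; the scalar field would need a different mechanism (for example rebuilding the container) if its accumulated value ever mattered.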
On GPU I get the following error:
┌ Warning: active variables passed by value to jl_new_task are not yet supported └ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59 ERROR: Enzyme compilation failed due to illegal type analysis. Current scope: ; Function Attrs: mustprogress willreturn define internal fastcc void @preprocess_julia_fill__33038({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,0]:Pointer, [-1,0,0,0,0]:Pointer, [-1,0,0,0,8]:Integer, [-1,0,0,0,16]:Pointer, [-1,0,0,16]:Integer, [-1,0,0,17]:Integer, [-1,0,0,18]:Integer, [-1,0,0,19]:Integer, [-1,0,0,20]:Integer, [-1,0,0,21]:Integer, [-1,0,0,22]:Integer, [-1,0,0,23]:Integer, [-1,0,0,24]:Integer, [-1,0,0,32]:Pointer, [-1,0,0,40]:Pointer, [-1,0,0,40,-1]:Integer, [-1,0,8]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="139959628162192" "enzymejl_parmtype_ref"="2" %0, float "enzyme_type"="{[-1]:Float@float}" "enzymejl_parmtype"="139978039813152" "enzymejl_parmtype_ref"="0" %1) unnamed_addr #657 !dbg !47671 { top: %2 = call {}*** @julia.get_pgcstack() %3 = call {}*** @julia.get_pgcstack() %4 = bitcast {}*** %2 to {}** %5 = getelementptr inbounds {}*, {}** %4, i64 -14 %6 = getelementptr inbounds {}*, {}** %5, i64 16 %7 = bitcast {}** %6 to i8** %8 = load i8*, i8** %7, align 8 %9 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %5, i64 8, {} 
addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !615 call void @zeroType.457({} addrspace(10)* %9, i8 0, i64 8), !enzyme_zerostack !590 %phic1 = bitcast {} addrspace(10)* %9 to {} addrspace(10)* addrspace(10)*, !enzyme_caststack !590 %10 = bitcast {}*** %3 to {}** %11 = getelementptr inbounds {}*, {}** %10, i64 -14 %12 = getelementptr inbounds {}*, {}** %11, i64 16 %13 = bitcast {}** %12 to i8** %14 = load i8*, i8** %13, align 8 %15 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %11, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !615 call void @zeroType.456({} addrspace(10)* %15, i8 0, i64 8), !enzyme_zerostack !590 %phic = bitcast {} addrspace(10)* %15 to {} addrspace(10)* addrspace(10)*, !enzyme_caststack !590 %phic19 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !4822 %16 = call {}*** @julia.get_pgcstack() #658 store {} addrspace(10)* null, {} addrspace(10)* addrspace(10)* %phic1, align 8, !noalias !47672 call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* null) store {} addrspace(10)* null, {} addrspace(10)* addrspace(10)* %phic, align 8, !noalias !47672 call void ({} addrspace(10)*, ...) 
@julia.write_barrier({} addrspace(10)* %15, {} addrspace(10)* null) %current_task329 = getelementptr inbounds {}**, {}*** %16, i64 -14 %current_task3 = bitcast {}*** %current_task329 to {}** %ptls_field30 = getelementptr inbounds {}**, {}*** %16, i64 2 %17 = bitcast {}*** %ptls_field30 to i64*** %ptls_load3132 = load i64**, i64*** %17, align 8, !tbaa !591 %18 = getelementptr inbounds i64*, i64** %ptls_load3132, i64 2 %safepoint = load i64*, i64** %18, align 8, !tbaa !595 fence syncscope("singlethread") seq_cst call void @julia.safepoint(i64* %safepoint) #658, !dbg !47675 fence syncscope("singlethread") seq_cst %bitcast_coercion = bitcast float %1 to i32, !dbg !47676 %19 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !47678 %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %19 unordered, align 8, !dbg !47678, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !614, !align !615 %20 = addrspacecast {} addrspace(10)* %getfield to i8 addrspace(11)*, !dbg !47681 %21 = getelementptr inbounds i8, i8 addrspace(11)* %20, i64 8, !dbg !47681 %22 = load i8, i8 addrspace(11)* %21, align 8, !dbg !47681, !tbaa !602, !alias.scope !606, !noalias !609 %23 = and i8 %22, 1, !dbg !47681 %.not = icmp eq i8 %23, 0, !dbg !47681 br i1 %.not, label %L8, label %L5, !dbg !47682L5: ; preds = %top %24 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10))) #659, !dbg !47683 %box = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !47683 %25 = bitcast {} addrspace(10)* %box to [1 x {} addrspace(10)] addrspace(10), !dbg !47683 %26 = extractvalue [1 x {} addrspace(10)] %24, 0, !dbg 
!47683 %27 = getelementptr [1 x {} addrspace(10)], [1 x {} addrspace(10)] addrspace(10) %25, i64 0, i64 0, !dbg !47683 store {} addrspace(10)* %26, {} addrspace(10)* addrspace(10)* %27, align 8, !dbg !47683, !tbaa !621, !alias.scope !606, !noalias !47684 %28 = addrspacecast {} addrspace(10)* %box to {} addrspace(12), !dbg !47683 call void @ijl_throw({} addrspace(12) %28) #661, !dbg !47683 unreachable, !dbg !47683
L8: ; preds = %top %29 = addrspacecast {} addrspace(10)* %getfield to {} addrspace(10)* addrspace(11), !dbg !47685 %getfield6 = load atomic {} addrspace(10), {} addrspace(10)* addrspace(11)* %29 unordered, align 8, !dbg !47685, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !628, !align !615 %30 = addrspacecast {} addrspace(10)* %getfield6 to i8 addrspace(11), !dbg !47687 %getfield_addr7 = getelementptr inbounds i8, i8 addrspace(11) %30, i64 40, !dbg !47687 %31 = bitcast i8 addrspace(11)* %getfield_addr7 to {} addrspace(10)* addrspace(11), !dbg !47687 %getfield8 = load atomic {} addrspace(10), {} addrspace(10)* addrspace(11)* %31 unordered, align 8, !dbg !47687, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !615, !align !615 %32 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %getfield8) #658, !dbg !47689 %33 = addrspacecast {} addrspace(10)* %getfield8 to {} addrspace(11), !dbg !47690 %34 = call nonnull {} @julia.pointer_from_objref({} addrspace(11)* noundef %33) #662, !dbg !47690 %ptr.i = bitcast {}* %34 to i64*, !dbg !47689 %rv.i = load atomic i64, i64* %ptr.i acquire, align 16, !dbg !47689 call void @llvm.julia.gc_preserve_end(token %32) #658, !dbg !47689 %.not33 = icmp eq i64 %rv.i, 0, !dbg !47692 br i1 %.not33, label %L17, label %L20, !dbg !47688
L17: ; preds = %L8 %35 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10))) #658, !dbg !47693 %box11 = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !47693 %36 = bitcast {} addrspace(10)* %box11 to [1 x {} addrspace(10)] addrspace(10), !dbg !47693 %37 = extractvalue [1 x {} addrspace(10)] %35, 0, !dbg !47693 %38 = getelementptr [1 x {} addrspace(10)], [1 x {} addrspace(10)] addrspace(10) %36, i64 0, i64 0, !dbg !47693 store {} addrspace(10)* %37, {} addrspace(10)* addrspace(10)* %38, align 8, !dbg !47693, !tbaa !621, !alias.scope !606, !noalias !47684 %39 = addrspacecast {} addrspace(10)* %box11 to {} addrspace(12), !dbg !47693 call void @ijl_throw({} addrspace(12) %39) #661, !dbg !47693 unreachable, !dbg !47693
L20: ; preds = %L8 %40 = addrspacecast {} addrspace(10)* %getfield6 to { {} addrspace(10), i64, i64, i8 } addrspace(11), !dbg !47694 %41 = getelementptr inbounds { {} addrspace(10), i64, i64, i8 }, { {} addrspace(10), i64, i64, i8 } addrspace(11)* %40, i64 0, i32 0, !dbg !47694 %42 = load {} addrspace(10), {} addrspace(10) addrspace(11)* %41, align 8, !dbg !47694, !tbaa !602, !alias.scope !606, !noalias !609 %43 = addrspacecast {} addrspace(10)* %42 to i8 addrspace(11), !dbg !47696 %44 = getelementptr inbounds i8, i8 addrspace(11) %43, i64 8, !dbg !47696 %45 = load i8, i8 addrspace(11)* %44, align 8, !dbg !47696, !tbaa !602, !alias.scope !606, !noalias !609 %46 = and i8 %45, 1, !dbg !47696 %.not34 = icmp eq i8 %46, 0, !dbg !47696 br i1 %.not34, label %L73, label %L27, !dbg !47698
L27: ; preds = %L20 %47 = call fastcc nonnull align 8 {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %42) #658, !dbg !47700 store volatile {} addrspace(10)* %42, {} addrspace(10)* addrspace(10)* %phic, align 8, !dbg !47701, !noalias !47672 call void ({} addrspace(10), ...) @julia.write_barrier({} addrspace(10) %15, {} addrspace(10)* %42), !dbg !47701 store volatile {} addrspace(10)* %47, {} addrspace(10)* addrspace(10)* %phic1, align 8, !dbg !47701, !noalias !47672 call void ({} addrspace(10), ...) @julia.write_barrier({} addrspace(10) %9, {} addrspace(10)* %47), !dbg !47701 store volatile i8 0, i8* %phic19, align 1, !dbg !47701, !tbaa !774, !alias.scope !776, !noalias !47702 %48 = call i64 @ijl_excstack_state() #658, !dbg !47701 %49 = call i32 @julia.except_enter() #663, !dbg !47701 %50 = icmp eq i32 %49, 0, !dbg !47701 br i1 %50, label %try, label %L46, !dbg !47701
L46: ; preds = %L27 %phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0. = load volatile {} addrspace(10), {} addrspace(10) addrspace(10)* %phic, align 8, !dbg !47703, !nonnull !590 %phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0. = load volatile {} addrspace(10), {} addrspace(10) addrspace(10)* %phic1, align 8, !dbg !47703, !nonnull !590 %phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0. = load volatile i8, i8* %phic19, align 1, !dbg !47703 call void @ijl_pop_handler(i32 noundef 1) #658, !dbg !47703 %51 = and i8 %phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0., 1, !dbg !47703 %phi.cast = icmp ne i8 %51, 0, !dbg !47703 br label %L51, !dbg !47703
L51: ; preds = %try, %L46 %value_phi = phi {} addrspace(10)* [ %42, %try ], [ %phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0., %L46 ] %value_phi15 = phi {} addrspace(10)* [ %47, %try ], [ %phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0., %L46 ] %value_phi17 = phi i1 [ true, %try ], [ %phi.cast, %L46 ] %52 = addrspacecast {} addrspace(10)* %value_phi15 to {} addrspace(11), !dbg !47704 %53 = icmp eq {} addrspace(11) %52, addrspacecast ({}* inttoptr (i64 139978194116616 to {}) to {} addrspace(11)), !dbg !47704 %54 = addrspacecast {} addrspace(10)* %value_phi to {} addrspace(11)* %55 = icmp eq {} addrspace(11)* %52, %54 %or.cond = select i1 %53, i1 true, i1 %55, !dbg !47704 br i1 %or.cond, label %L67, label %L62, !dbg !47704
L62: ; preds = %L51 %56 = addrspacecast {} addrspace(10)* %value_phi15 to i8 addrspace(11), !dbg !47705 %57 = getelementptr inbounds i8, i8 addrspace(11) %56, i64 8, !dbg !47705 %58 = load i8, i8 addrspace(11)* %57, align 8, !dbg !47705, !tbaa !846, !alias.scope !606, !noalias !609 %59 = and i8 %58, 1, !dbg !47705 %.not35 = icmp eq i8 %59, 0, !dbg !47705 br i1 %.not35, label %L67, label %L65, !dbg !47704
L65: ; preds = %L62 %60 = call fastcc nonnull {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %value_phi15) #658, !dbg !47707 br label %L67, !dbg !47707
L67: ; preds = %L65, %L62, %L51 br i1 %50, label %L71, label %L69, !dbg !47707
L69: ; preds = %L67 call fastcc void @julia_rethrow_31152() #661, !dbg !47707 unreachable, !dbg !47707
L71: ; preds = %L67 br i1 %value_phi17, label %ok, label %err, !dbg !47707
L73: ; preds = %L20 call fastcc void @julia_error_31187({} addrspace(10)* nofree noundef nonnull align 32 addrspacecast ({}* inttoptr (i64 139962719163168 to {}) to {} addrspace(10))) #661, !dbg !47708 unreachable, !dbg !47708
try: ; preds = %L27 %61 = call fastcc i64 @julia_unsafe_convert_32014({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) %0) #658, !dbg !47709 %62 = addrspacecast {} addrspace(10)* %0 to i8 addrspace(11), !dbg !47713 %63 = getelementptr inbounds i8, i8 addrspace(11) %62, i64 24, !dbg !47713 %aggregate_load_box.sroa.0.0..sroa_idx = bitcast i8 addrspace(11)* %63 to i64 addrspace(11), !dbg !47713 %aggregate_load_box.sroa.0.0.copyload = load i64, i64 addrspace(11) %aggregate_load_box.sroa.0.0..sroa_idx, align 8, !dbg !47713, !tbaa !710, !alias.scope !711, !noalias !47716 %aggregate_load_box.sroa.2.0..sroa_idx25 = getelementptr inbounds i8, i8 addrspace(11)* %62, i64 32, !dbg !47713 %64 = bitcast i8 addrspace(11)* %aggregate_load_box.sroa.2.0..sroa_idx25 to i64 addrspace(11), !dbg !47713 %aggregate_load_box.sroa.2.0.copyload = load i64, i64 addrspace(11) %64, align 8, !dbg !47713, !tbaa !710, !alias.scope !711, !noalias !47716 %65 = mul i64 %aggregate_load_box.sroa.2.0.copyload, %aggregate_load_box.sroa.0.0.copyload, !dbg !47717 call fastcc void @julia_set__33047(i64 zeroext %61, i32 zeroext %bitcast_coercion, i64 signext %65) #658, !dbg !47712 store volatile i8 1, i8* %phic19, align 1, !dbg !47703, !tbaa !774, !alias.scope !776, !noalias !47702 call void @ijl_pop_handler(i32 noundef 1) #658, !dbg !47703 br label %L51, !dbg !47703
err: ; preds = %L71 call void @ijl_undefined_var_error({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 139978194630336 to {}) to {} addrspace(12))) #661, !dbg !47707 unreachable, !dbg !47707
ok: ; preds = %L71 ret void, !dbg !47699 }
Type analysis state:
Illegal updateAnalysis prev:{[-1]:Integer} new: {[-1]:Float@float} val: %bitcast_coercion = bitcast float %1 to i32, !dbg !603 origin= %bitcast_coercion = bitcast float %1 to i32, !dbg !603 MethodInstance for fill!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
Caused by: Stacktrace: [1] reinterpret @ ./essentials.jl:581 [2] fill! @ ~/.julia/packages/CUDA/jdJ7Z/src/array.jl:829
Stacktrace:
[1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:1690
[2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/2FwRI/src/api.jl:154
[3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:3177
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5070
[5] codegen
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:4477 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755
[7] _thunk
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5793 [inlined]
[9] (::Enzyme.Compiler.var"#554#555"{…})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5859
[10] JuliaContext(f::Enzyme.Compiler.var"#554#555"{…}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
[11] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
[12] #s2027#553
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5811 [inlined]
[13]
@ Enzyme.Compiler ./none:0
[14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[15] autodiff
@ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:286 [inlined]
[16] autodiff
@ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:315 [inlined]
[17] autodiff(::ReverseMode{…}, ::typeof(loss), ::Duplicated{…}, ::Const{…}, ::Const{…})
@ Enzyme ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:300
[18] top-level scope
@ ~/juliadev/Flux/mlp.jl:37
Some type information was truncated. Use `show(err)` to see complete types.
You'll need https://github.com/JuliaGPU/CUDA.jl/pull/2371 and then https://github.com/JuliaPackaging/Yggdrasil/pull/8666. It then hits a cublasscal issue, which I stopped investigating to go get dinner.
I think the basic interface needed is a nice gradient function.
Enzyme's own gradient should now do this, as make_zero understands nested structures:
julia> sh = [1f0, 2f0]; nt = (a=sh, b=sh, c=copy(sh));
julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = Float32[2.0, 2.0], b = Float32[2.0, 2.0], c = Float32[1.0, 1.0])
(jl_o1ZBlk) pkg> st Enzyme
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_o1ZBlk/Project.toml`
[7da242da] Enzyme v0.12.4
The above example doesn't work for me, but I believe function gradient_ez(f, x...) can be deleted to have just this:
for epoch in 1:epochs
    g = Enzyme.gradient(Reverse, m -> loss(m, X, y), model) # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
A slight problem in make_zero! is that it sets the arrays to zero but not the scalar fields, so those are going to be accumulated. That can be fixed later, and in principle it is not even a problem since scalars are not updated by the optimizer.
Right. For those coming from Zygote, it's slightly odd that the gradient contains numbers for non-diff things. But I believe Optimisers.jl's idea of what parameters can be updated is narrow enough that it will only use true gradient numbers from Enzyme.jl.
This should be resolved by https://github.com/FluxML/Flux.jl/pull/2446
Like I say in that PR
""" I have no opinions on the design/API and I will give this PR to you all to make it however you feel (and I will go back to staring at CUDA).
I will note that perf atm is unclear and is worth investigating. However, before we do that, having a good way to run/test things is critical, hence this PR. """
edit: accidentally reran cpu, please ignore below.
CUDA works on the simple example now. It does require either CUDA#master with the already merged branches or, hopefully, a backport release from CUDA.jl via https://github.com/JuliaGPU/CUDA.jl/pull/2375, as well as an Enzyme_jll bump.
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ cat orig.jl
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
# device = Flux.cpu # CPU training
device = Flux.gpu # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1] # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ ~/git/Enzyme.jl/julia-1.10.2/bin/julia --project orig.jl
┌ Warning: Package cuDNN not found in current path.
│ - Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
│ - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
└ @ FluxCUDAExt ~/git/Flux.jl/ext/FluxCUDAExt/FluxCUDAExt.jl:57
┌ Info: Epoch: 0
│ loss = 2.7904227f0
└ accuracy = 0.125
┌ Info: Epoch: 1
│ loss = 2.5142982f0
└ accuracy = 0.15625
┌ Info: Epoch: 2
│ loss = 2.2610319f0
└ accuracy = 0.203125
┌ Info: Epoch: 3
│ loss = 2.029134f0
└ accuracy = 0.28125
┌ Info: Epoch: 4
│ loss = 1.8172197f0
└ accuracy = 0.3515625
┌ Info: Epoch: 5
│ loss = 1.6268556f0
└ accuracy = 0.4375
┌ Info: Epoch: 6
│ loss = 1.4554112f0
└ accuracy = 0.546875
┌ Info: Epoch: 7
│ loss = 1.3014916f0
└ accuracy = 0.6640625
┌ Info: Epoch: 8
│ loss = 1.163165f0
└ accuracy = 0.7890625
┌ Info: Epoch: 9
│ loss = 1.0413302f0
└ accuracy = 0.8515625
┌ Info: Epoch: 10
│ loss = 0.93555194f0
└ accuracy = 0.8515625
┌ Info: Epoch: 11
│ loss = 0.84206563f0
└ accuracy = 0.8828125
┌ Info: Epoch: 12
│ loss = 0.7600569f0
└ accuracy = 0.90625
┌ Info: Epoch: 13
│ loss = 0.6874082f0
└ accuracy = 0.921875
┌ Info: Epoch: 14
│ loss = 0.6230737f0
└ accuracy = 0.9296875
┌ Info: Epoch: 15
│ loss = 0.5663827f0
└ accuracy = 0.9609375
┌ Info: Epoch: 16
│ loss = 0.5165455f0
└ accuracy = 0.96875
┌ Info: Epoch: 17
│ loss = 0.4719535f0
└ accuracy = 0.96875
┌ Info: Epoch: 18
│ loss = 0.4319139f0
└ accuracy = 0.9765625
┌ Info: Epoch: 19
│ loss = 0.39577293f0
└ accuracy = 0.984375
┌ Info: Epoch: 20
│ loss = 0.36347917f0
└ accuracy = 0.984375
┌ Info: Epoch: 21
│ loss = 0.33449084f0
└ accuracy = 0.9921875
┌ Info: Epoch: 22
│ loss = 0.30846184f0
└ accuracy = 0.9921875
┌ Info: Epoch: 23
│ loss = 0.28476223f0
└ accuracy = 0.9921875
┌ Info: Epoch: 24
│ loss = 0.26318714f0
└ accuracy = 1.0
┌ Info: Epoch: 25
│ loss = 0.24353352f0
└ accuracy = 1.0
┌ Info: Epoch: 26
│ loss = 0.22557218f0
└ accuracy = 1.0
┌ Info: Epoch: 27
│ loss = 0.20921068f0
└ accuracy = 1.0
┌ Info: Epoch: 28
│ loss = 0.19429381f0
└ accuracy = 1.0
┌ Info: Epoch: 29
│ loss = 0.18054952f0
└ accuracy = 1.0
┌ Info: Epoch: 30
│ loss = 0.16796987f0
└ accuracy = 1.0
┌ Info: Epoch: 31
│ loss = 0.1563463f0
└ accuracy = 1.0
┌ Info: Epoch: 32
│ loss = 0.14567412f0
└ accuracy = 1.0
┌ Info: Epoch: 33
│ loss = 0.13588753f0
└ accuracy = 1.0
┌ Info: Epoch: 34
│ loss = 0.12687433f0
└ accuracy = 1.0
┌ Info: Epoch: 35
│ loss = 0.11857266f0
└ accuracy = 1.0
┌ Info: Epoch: 36
│ loss = 0.11093213f0
└ accuracy = 1.0
┌ Info: Epoch: 37
│ loss = 0.103871785f0
└ accuracy = 1.0
┌ Info: Epoch: 38
│ loss = 0.09736837f0
└ accuracy = 1.0
┌ Info: Epoch: 39
│ loss = 0.09138645f0
└ accuracy = 1.0
┌ Info: Epoch: 40
│ loss = 0.08586908f0
└ accuracy = 1.0
┌ Info: Epoch: 41
│ loss = 0.080786735f0
└ accuracy = 1.0
┌ Info: Epoch: 42
│ loss = 0.07610354f0
└ accuracy = 1.0
┌ Info: Epoch: 43
│ loss = 0.07179588f0
└ accuracy = 1.0
┌ Info: Epoch: 44
│ loss = 0.06783663f0
└ accuracy = 1.0
┌ Info: Epoch: 45
│ loss = 0.06419177f0
└ accuracy = 1.0
┌ Info: Epoch: 46
│ loss = 0.060845155f0
└ accuracy = 1.0
┌ Info: Epoch: 47
│ loss = 0.057761367f0
└ accuracy = 1.0
┌ Info: Epoch: 48
│ loss = 0.0549154f0
└ accuracy = 1.0
┌ Info: Epoch: 49
│ loss = 0.05228231f0
└ accuracy = 1.0
┌ Info: Epoch: 50
│ loss = 0.049845647f0
└ accuracy = 1.0
┌ Info: Epoch: 51
│ loss = 0.047589153f0
└ accuracy = 1.0
┌ Info: Epoch: 52
│ loss = 0.045498513f0
└ accuracy = 1.0
┌ Info: Epoch: 53
│ loss = 0.04355742f0
└ accuracy = 1.0
┌ Info: Epoch: 54
│ loss = 0.04175187f0
└ accuracy = 1.0
┌ Info: Epoch: 55
│ loss = 0.04007356f0
└ accuracy = 1.0
┌ Info: Epoch: 56
│ loss = 0.038507923f0
└ accuracy = 1.0
┌ Info: Epoch: 57
│ loss = 0.037045095f0
└ accuracy = 1.0
┌ Info: Epoch: 58
│ loss = 0.035674226f0
└ accuracy = 1.0
┌ Info: Epoch: 59
│ loss = 0.034392048f0
└ accuracy = 1.0
┌ Info: Epoch: 60
│ loss = 0.033194654f0
└ accuracy = 1.0
┌ Info: Epoch: 61
│ loss = 0.032058075f0
└ accuracy = 1.0
┌ Info: Epoch: 62
│ loss = 0.030996136f0
└ accuracy = 1.0
┌ Info: Epoch: 63
│ loss = 0.02999451f0
└ accuracy = 1.0
┌ Info: Epoch: 64
│ loss = 0.029050402f0
└ accuracy = 1.0
┌ Info: Epoch: 65
│ loss = 0.02815985f0
└ accuracy = 1.0
┌ Info: Epoch: 66
│ loss = 0.027319008f0
└ accuracy = 1.0
┌ Info: Epoch: 67
│ loss = 0.02652272f0
└ accuracy = 1.0
┌ Info: Epoch: 68
│ loss = 0.025767544f0
└ accuracy = 1.0
┌ Info: Epoch: 69
│ loss = 0.025051065f0
└ accuracy = 1.0
┌ Info: Epoch: 70
│ loss = 0.024369944f0
└ accuracy = 1.0
┌ Info: Epoch: 71
│ loss = 0.023721226f0
└ accuracy = 1.0
┌ Info: Epoch: 72
│ loss = 0.023103705f0
└ accuracy = 1.0
┌ Info: Epoch: 73
│ loss = 0.022514593f0
└ accuracy = 1.0
┌ Info: Epoch: 74
│ loss = 0.021952922f0
└ accuracy = 1.0
┌ Info: Epoch: 75
│ loss = 0.021417053f0
└ accuracy = 1.0
┌ Info: Epoch: 76
│ loss = 0.020906389f0
└ accuracy = 1.0
┌ Info: Epoch: 77
│ loss = 0.0204159f0
└ accuracy = 1.0
┌ Info: Epoch: 78
│ loss = 0.01994732f0
└ accuracy = 1.0
┌ Info: Epoch: 79
│ loss = 0.01949887f0
└ accuracy = 1.0
┌ Info: Epoch: 80
│ loss = 0.01906871f0
└ accuracy = 1.0
┌ Info: Epoch: 81
│ loss = 0.018656129f0
└ accuracy = 1.0
┌ Info: Epoch: 82
│ loss = 0.018260362f0
└ accuracy = 1.0
┌ Info: Epoch: 83
│ loss = 0.017879806f0
└ accuracy = 1.0
┌ Info: Epoch: 84
│ loss = 0.017513612f0
└ accuracy = 1.0
┌ Info: Epoch: 85
│ loss = 0.017161498f0
└ accuracy = 1.0
┌ Info: Epoch: 86
│ loss = 0.01682241f0
└ accuracy = 1.0
┌ Info: Epoch: 87
│ loss = 0.016495718f0
└ accuracy = 1.0
┌ Info: Epoch: 88
│ loss = 0.016181245f0
└ accuracy = 1.0
┌ Info: Epoch: 89
│ loss = 0.015877243f0
└ accuracy = 1.0
┌ Info: Epoch: 90
│ loss = 0.0155781405f0
└ accuracy = 1.0
┌ Info: Epoch: 91
│ loss = 0.01528422f0
└ accuracy = 1.0
┌ Info: Epoch: 92
│ loss = 0.014997441f0
└ accuracy = 1.0
┌ Info: Epoch: 93
│ loss = 0.014718127f0
└ accuracy = 1.0
┌ Info: Epoch: 94
│ loss = 0.014446221f0
└ accuracy = 1.0
┌ Info: Epoch: 95
│ loss = 0.014181806f0
└ accuracy = 1.0
┌ Info: Epoch: 96
│ loss = 0.013925277f0
└ accuracy = 1.0
┌ Info: Epoch: 97
│ loss = 0.013677116f0
└ accuracy = 1.0
┌ Info: Epoch: 98
│ loss = 0.013437184f0
└ accuracy = 1.0
┌ Info: Epoch: 99
│ loss = 0.013204632f0
└ accuracy = 1.0
┌ Info: Epoch: 100
│ loss = 0.012979296f0
└ accuracy = 1.0
@CarloLucibello this gradient_ez is very useful. Thanks! Would it be possible to also have an option to run Enzyme from Zygote? Or an example, similar to the one with gradient_ez, of how to add a Zygote.@adjoint such that Enzyme is used instead of Zygote for one custom Flux layer, while the rest still uses Zygote?
I am thinking of some way we could transition smoothly without switching to one completely.
The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it.
@darsnack I'd actually love to revisit the dream of DI + Flux one of these days.
- For multiple inputs, I think I see a way to support additional constant inputs without too much pain (https://github.com/gdalle/DifferentiationInterface.jl/issues/311). Apparently it's what you need for e.g. X and y in training.
- For array-only, the trouble is not supporting general structs, it's testing them. We've had this discussion together, and I don't want to commit to something that a) doesn't work for every backend and b) will probably be undertested because arbitrary structs can be, well, arbitrary. In my view, non-arrays cannot be in the DI API because there will be plenty of cases that fail, and it's very hard to say which ones ahead of time.
To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
Why not create a package named DifferentiationInterfaceForFlux or something, which relies on DI but tests compatibility with Flux layers and makes it part of its API? In other words, if I change something in DI that removes compatibility with Flux layers, the glue package could still be frozen to its current version until it gets resolved.