Enzyme.jl
Enzyme.jl copied to clipboard
Higher order differentiation, `autodiff_deferred` and `CustomRules`
I have been trying to write a custom rule in the context of higher order differentiation. However, I simply cannot get it to work. Happy to try to track this issue down further, but I don't know where to start.
The (relatively) minimal code is the following. The return of the custom rule is not the correct differential, but it is simple enough to try to work with.
using Enzyme
# Import `forward` by name so that a new method can be added to it below.
import .EnzymeRules: forward
# NOTE(review): presumably enables the verbose IR printing shown in the trace
# further down in this report — confirm against the Enzyme API docs.
Enzyme.API.printall!(true)
# Function under differentiation: the squared difference of its two arguments.
fun(x, y) = (x - y)^2
# Custom forward-mode rule for `fun`, added as a method of the imported `forward`.
# The method matches when `fun` is passed as `Const` and `y` is `Const`; `o` and
# `x` carry no type restriction, and `o` is not used in the body.
# NOTE: as stated in the report text, the returned value is deliberately NOT the
# true differential of `fun`; it is only a simple stand-in for debugging.
function forward(func::Const{typeof(fun)}, o, x, y::Const)
    # Marker output so it is visible whenever this rule body actually runs.
    println("Custom Rule")
    # Pairs x.val + y.val with a fixed 1.0.
    # NOTE(review): presumably these are the primal and shadow slots of
    # `Duplicated` — confirm.
    return Duplicated(x.val + y.val, 1.0)
end
# Inputs for the reproducer (non-const globals referenced by the closures below).
x = 1.
y = 1.1
# First-order case: forward-mode `autodiff` of `fun`, passing `Duplicated(x, 1.0)`
# for the first argument and `y` unannotated for the second.
df(y) = autodiff(Forward, fun, Duplicated, Duplicated(x, 1.0), y)
df(y)
# Second-order case: an outer forward-mode `autodiff` over a closure that itself
# calls `autodiff_deferred` on `fun`. Per the stack trace in this report, it is
# this outer `autodiff` call (closure + single `Duplicated{Float64}` argument)
# that raises the KeyError.
only(autodiff(
    Forward,
    yt -> autodiff_deferred(Forward, fun, Duplicated, Duplicated(x, 1.0), yt),
    Duplicated,
    Duplicated(y, 1.0)))
This code throws the following error:
ERROR: KeyError: key "julia_fun_4703" not found
Stacktrace:
[1] getindex
@ ~/.julia/packages/LLVM/Od0DH/src/core/module.jl:245 [inlined]
[2] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:8874
[3] codegen
@ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:8723 [inlined]
[4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9671
[5] _thunk
@ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9671 [inlined]
[6] cached_compilation
@ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9705 [inlined]
[7] (::Enzyme.Compiler.var"#475#476"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9768
[8] JuliaContext(f::Enzyme.Compiler.var"#475#476"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
@ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:47
[9] #s292#474
@ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9723 [inlined]
[10] var"#s292#474"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
@ Enzyme.Compiler ./none:0
[11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[12] autodiff(#unused#::ForwardMode{FFIABI}, f::Const{var"#25#26"}, #unused#::Type{Duplicated}, args::Duplicated{Float64})
@ Enzyme ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:328
[13] autodiff(::ForwardMode{FFIABI}, ::var"#25#26", ::Type, ::Duplicated{Float64})
@ Enzyme ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:222
and the Enzyme trace
after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4671mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
%2 = call double @julia_fun_4671(double %0, double %1) #4
ret double %2
}
; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4671mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
%2 = alloca [2 x double], align 8
%3 = call {}*** @julia.get_pgcstack()
%4 = call {}*** @julia.get_pgcstack()
%5 = bitcast {}*** %4 to {}**
%6 = getelementptr inbounds {}*, {}** %5, i64 -13
%7 = getelementptr inbounds {}*, {}** %6, i64 15
%8 = bitcast {}** %7 to i8**
%9 = load i8*, i8** %8, align 8
%10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
%11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
%12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
%13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
store double %0, double addrspace(11)* %13, align 8
%14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
store double %"'", double addrspace(11)* %14, align 8
%15 = bitcast {}*** %3 to {}**
%16 = getelementptr inbounds {}*, {}** %15, i64 -13
%17 = getelementptr inbounds {}*, {}** %16, i64 15
%18 = bitcast {}** %17 to i8**
%19 = load i8*, i8** %18, align 8
%20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
%21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
%22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
%23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
store double %1, double addrspace(11)* %23, align 8
call void @llvm.experimental.noalias.scope.decl(metadata !19)
%24 = call {}*** @julia.get_pgcstack()
call fastcc void @julia_println_4676() #8, !dbg !22, !noalias !19
%25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
%26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
%27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%29 = fadd double %27, %28, !dbg !28
%.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%30 = load [2 x double], [2 x double]* %2, align 8
%31 = extractvalue [2 x double] %30, 0
%32 = extractvalue [2 x double] %30, 1
%33 = insertvalue { double, double } undef, double %31, 0
%34 = insertvalue { double, double } %33, double %32, 1
ret { double, double } %34
}
after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4681mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
%2 = call double @julia_fun_4681(double %0, double %1) #4
ret double %2
}
; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4681mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
%2 = alloca [2 x double], align 8
%3 = call {}*** @julia.get_pgcstack()
%4 = call {}*** @julia.get_pgcstack()
%5 = bitcast {}*** %4 to {}**
%6 = getelementptr inbounds {}*, {}** %5, i64 -13
%7 = getelementptr inbounds {}*, {}** %6, i64 15
%8 = bitcast {}** %7 to i8**
%9 = load i8*, i8** %8, align 8
%10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
%11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
%12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
%13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
store double %0, double addrspace(11)* %13, align 8
%14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
store double %"'", double addrspace(11)* %14, align 8
%15 = bitcast {}*** %3 to {}**
%16 = getelementptr inbounds {}*, {}** %15, i64 -13
%17 = getelementptr inbounds {}*, {}** %16, i64 15
%18 = bitcast {}** %17 to i8**
%19 = load i8*, i8** %18, align 8
%20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
%21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
%22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
%23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
store double %1, double addrspace(11)* %23, align 8
call void @llvm.experimental.noalias.scope.decl(metadata !19)
%24 = call {}*** @julia.get_pgcstack()
call fastcc void @julia_println_4686() #8, !dbg !22, !noalias !19
%25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
%26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
%27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%29 = fadd double %27, %28, !dbg !28
%.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%30 = load [2 x double], [2 x double]* %2, align 8
%31 = extractvalue [2 x double] %30, 0
%32 = extractvalue [2 x double] %30, 1
%33 = insertvalue { double, double } undef, double %31, 0
%34 = insertvalue { double, double } %33, double %32, 1
ret { double, double } %34
}
after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4693mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
%2 = call double @julia_fun_4693(double %0, double %1) #4
ret double %2
}
; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4693mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
%2 = alloca [2 x double], align 8
%3 = call {}*** @julia.get_pgcstack()
%4 = call {}*** @julia.get_pgcstack()
%5 = bitcast {}*** %4 to {}**
%6 = getelementptr inbounds {}*, {}** %5, i64 -13
%7 = getelementptr inbounds {}*, {}** %6, i64 15
%8 = bitcast {}** %7 to i8**
%9 = load i8*, i8** %8, align 8
%10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
%11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
%12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
%13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
store double %0, double addrspace(11)* %13, align 8
%14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
store double %"'", double addrspace(11)* %14, align 8
%15 = bitcast {}*** %3 to {}**
%16 = getelementptr inbounds {}*, {}** %15, i64 -13
%17 = getelementptr inbounds {}*, {}** %16, i64 15
%18 = bitcast {}** %17 to i8**
%19 = load i8*, i8** %18, align 8
%20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
%21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
%22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
%23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
store double %1, double addrspace(11)* %23, align 8
call void @llvm.experimental.noalias.scope.decl(metadata !19)
%24 = call {}*** @julia.get_pgcstack()
call fastcc void @julia_println_4698() #8, !dbg !22, !noalias !19
%25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
%26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
%27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%29 = fadd double %27, %28, !dbg !28
%.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%30 = load [2 x double], [2 x double]* %2, align 8
%31 = extractvalue [2 x double] %30, 0
%32 = extractvalue [2 x double] %30, 1
%33 = insertvalue { double, double } undef, double %31, 0
%34 = insertvalue { double, double } %33, double %32, 1
ret { double, double } %34
}
after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4703mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
%2 = call double @julia_fun_4703(double %0, double %1) #4
ret double %2
}
; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4703mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
%2 = alloca [2 x double], align 8
%3 = call {}*** @julia.get_pgcstack()
%4 = call {}*** @julia.get_pgcstack()
%5 = bitcast {}*** %4 to {}**
%6 = getelementptr inbounds {}*, {}** %5, i64 -13
%7 = getelementptr inbounds {}*, {}** %6, i64 15
%8 = bitcast {}** %7 to i8**
%9 = load i8*, i8** %8, align 8
%10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
%11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
%12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
%13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
store double %0, double addrspace(11)* %13, align 8
%14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
store double %"'", double addrspace(11)* %14, align 8
%15 = bitcast {}*** %3 to {}**
%16 = getelementptr inbounds {}*, {}** %15, i64 -13
%17 = getelementptr inbounds {}*, {}** %16, i64 15
%18 = bitcast {}** %17 to i8**
%19 = load i8*, i8** %18, align 8
%20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
%21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
%22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
%23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
store double %1, double addrspace(11)* %23, align 8
call void @llvm.experimental.noalias.scope.decl(metadata !19)
%24 = call {}*** @julia.get_pgcstack()
call fastcc void @julia_println_4708() #8, !dbg !22, !noalias !19
%25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
%26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
%27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
%29 = fadd double %27, %28, !dbg !28
%.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
%30 = load [2 x double], [2 x double]* %2, align 8
%31 = extractvalue [2 x double] %30, 0
%32 = extractvalue [2 x double] %30, 1
%33 = insertvalue { double, double } undef, double %31, 0
%34 = insertvalue { double, double } %33, double %32, 1
ret { double, double } %34
}
So at the moment, I don't think the CustomRules are guaranteed to apply at the higher-order level; making them apply there is something we need to do.