Enzyme.jl icon indicating copy to clipboard operation
Enzyme.jl copied to clipboard

Higher order differentiation, `autodiff_deferred` and `CustomRules`

Open Crown421 opened this issue 1 year ago • 1 comment

I have been trying to write a custom rule in the context of higher order differentiation. However, I simply cannot get it to work. Happy to try to track this issue down further, but I don't know where to start.

The (relatively) minimal code is the following. The return of the custom rule is not the correct differential, but it is simple enough to try to work with.

using Enzyme
import .EnzymeRules: forward
# Turn on Enzyme's verbose printing for this session.
# NOTE(review): presumably this is what produces the "after simplification" IR
# dumps quoted further down in the issue — confirm against the Enzyme API docs.
Enzyme.API.printall!(true)

# Function under differentiation: the squared difference of `x` and `y`.
fun(x, y) = (x - y)^2

# Custom forward-mode rule for `fun`, defined by adding a method to the imported
# `EnzymeRules.forward`.
#
# Arguments:
#   func - `fun` wrapped in `Const` (this is what the method dispatches on).
#   o    - unconstrained and unused in the body.
#          NOTE(review): presumably the requested return annotation — confirm
#          against the EnzymeRules documentation.
#   x    - untyped here; only `x.val` is read.
#   y    - must be a `Const`; only `y.val` is read.
#
# Prints "Custom Rule" each time the rule body runs, then returns
# `Duplicated(x.val + y.val, 1.0)`. As the issue text states, this is not the
# correct differential of `fun`; it is kept simple for the sake of the reproducer.
function forward(func::Const{typeof(fun)}, o, x, y::Const)
    println("Custom Rule")
    return Duplicated(x.val + y.val, 1.0)
end

# Evaluation point used by both calls below.
x = 1.
y = 1.1

# First-order case: forward-mode `autodiff` of `fun`, with `x` wrapped as
# `Duplicated(x, 1.0)` and the argument `y` passed through without an annotation.
df(y) = autodiff(Forward, fun, Duplicated, Duplicated(x, 1.0), y)
df(y)

# Second-order case: an outer forward-mode `autodiff` over a closure in `yt`
# whose body is an inner `autodiff_deferred` call on `fun`; `only` unwraps the
# single-element result. The stack trace reported in the issue passes through
# `autodiff` applied to an anonymous function, i.e. this call.
only(autodiff(
    Forward,
    yt -> autodiff_deferred(Forward, fun, Duplicated, Duplicated(x, 1.0), yt),
    Duplicated,
    Duplicated(y, 1.0)))

This code throws the following error:

ERROR: KeyError: key "julia_fun_4703" not found
Stacktrace:
  [1] getindex
    @ ~/.julia/packages/LLVM/Od0DH/src/core/module.jl:245 [inlined]
  [2] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:8874
  [3] codegen
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:8723 [inlined]
  [4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9671
  [5] _thunk
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9671 [inlined]
  [6] cached_compilation
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9705 [inlined]
  [7] (::Enzyme.Compiler.var"#475#476"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9768
  [8] JuliaContext(f::Enzyme.Compiler.var"#475#476"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:47
  [9] #s292#474
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9723 [inlined]
 [10] var"#s292#474"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] autodiff(#unused#::ForwardMode{FFIABI}, f::Const{var"#25#26"}, #unused#::Type{Duplicated}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:328
 [13] autodiff(::ForwardMode{FFIABI}, ::var"#25#26", ::Type, ::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:222

and the Enzyme trace

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4671mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4671(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4671mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4676() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4681mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4681(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4681mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4686() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4693mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4693(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4693mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4698() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4703mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4703(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4703mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4708() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

Crown421 avatar Sep 17 '23 17:09 Crown421

So at the moment, I don't think the CustomRules are guaranteed to apply at the higher-order level; making them apply there is something we need to do.

wsmoses avatar Sep 18 '23 00:09 wsmoses