Enzyme.jl
Enzyme.jl copied to clipboard
Enzyme doesn't work for `AdvancedVI` Part IX
Hi, I finally isolated an MWE for the current error. The error below seems to be related to importing the rrule: if I don't use `MixedADLogDensityProblem` (i.e. call the wrapped problem directly), the problem goes away.
using Accessors
using ChainRulesCore
using Enzyme
using Statistics
using DifferentiationInterface
using ADTypes
"""
    MixedADLogDensityProblem{Problem}

Wrapper around a log-density problem. Calls to `logdensity` on the wrapper are routed
through a ChainRules `rrule` (imported into Enzyme with `@import_rrule`). That rule takes
its gradient from `logdensity_and_gradient` on the wrapped problem instead of having
Enzyme differentiate the problem itself.
"""
struct MixedADLogDensityProblem{Problem}
problem::Problem  # wrapped problem; used with `logdensity`, `logdensity_and_gradient`, `dimension`
end
# The wrapper has the same dimension as the problem it wraps.
dimension(mixedad_prob::MixedADLogDensityProblem) = dimension(mixedad_prob.problem)
# Primal value: delegate to the wrapped problem. Derivatives for this exact signature
# come from the ChainRules `rrule` that Enzyme imports via `@import_rrule`.
logdensity(mixedad_prob::MixedADLogDensityProblem, x::AbstractArray) =
    logdensity(mixedad_prob.problem, x)
# Reverse rule for `logdensity(::MixedADLogDensityProblem, ::AbstractArray)`.
# The value and gradient are computed eagerly on the forward pass by
# `logdensity_and_gradient` on the wrapped problem. The pullback scales that stored
# gradient by the incoming cotangent `∂y`. No tangent is returned for the function
# itself or for the wrapper argument.
function ChainRulesCore.rrule(
::typeof(logdensity),
mixedad_prob::MixedADLogDensityProblem,
x::AbstractArray,
)
# Value and gradient of the inner problem (ForwardDiff-based for SubsampledNormals).
ℓπ, ∇ℓπ = logdensity_and_gradient(mixedad_prob.problem, x)
# NOTE(review): this closure captures ∇ℓπ. Through `@import_rrule` it becomes part of
# Enzyme's tape (see TapeT in the reported error), which is where the innerTy/tape-type
# mismatch shows up. It is deliberately left unchanged so the reproducer stays intact.
function logdensity_pullback(∂y)
# Lazily evaluated cotangent for `x`.
∂x = ChainRulesCore.@thunk(∂y' * ∇ℓπ)
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ∂x
end
return ℓπ, logdensity_pullback
end
# Have Enzyme use the ChainRules `rrule` for
# `logdensity(::MixedADLogDensityProblem, ::AbstractArray)` instead of differentiating
# through it. Per the report, the failure only appears when this wrapper/rule path is used.
Enzyme.@import_rrule(typeof(logdensity), MixedADLogDensityProblem, AbstractArray)
# Objective handed to the AD backend: the log density of `prob` at `x`.
# `x` comes first so DifferentiationInterface differentiates w.r.t. it, with `prob` passed
# as a `Constant` context.
forward_ad(x, prob) = logdensity(prob, x)
using Distributions
using Random
using LinearAlgebra
using ForwardDiff
"""
    SubsampledNormals{D<:Normal,F<:Real,C}

Toy 1-D model whose log density is `likeadj` times the sum of the log-pdfs of `dists`,
evaluated at a single scalar coordinate.
"""
struct SubsampledNormals{D<:Normal,F<:Real,C}
dists::Vector{D}  # component normal distributions
likeadj::F        # multiplier applied to the summed log-pdf
capability::C     # stored only; not read by any code shown here
end
"""
    SubsampledNormals(rng, n_normals, capability)

Build a model with `n_normals` unit-variance normals. Their means are standard-normal
draws from `rng`. The likelihood adjustment is fixed at `1.0` and `capability` is stored
as given.
"""
function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int, capability)
    means = randn(rng, n_normals)
    components = [Normal(μ, 1.0) for μ in means]
    D = eltype(components)
    return SubsampledNormals{D,Float64,typeof(capability)}(components, 1.0, capability)
end
# Log density of `m` at the scalar `only(x)`: the sum of the component log-pdfs scaled by
# `m.likeadj`. `only` throws if `x` does not contain exactly one element.
function logdensity(m::SubsampledNormals, x)
    θ = only(x)
    total = mapreduce(Base.Fix2(logpdf, θ), +, m.dists)
    return m.likeadj * total
end
# Return `(value, gradient)` of `logdensity(m, ·)` at `x`; the gradient uses ForwardDiff.
function logdensity_and_gradient(m::SubsampledNormals, x)
    value = logdensity(m, x)
    grad = ForwardDiff.gradient(ξ -> logdensity(m, ξ), x)
    return (value, grad)
end
"""
    subsamplednormal(rng, n_data; capability=1)

Build a `SubsampledNormals` model with `n_data` unit-variance normals whose means are
drawn from `rng`.

NOTE(review): the `capability` keyword is accepted but ignored. The model is always built
with `capability = true` (a `Bool`). This preserves the original reproducer, whose error
message refers to `SubsampledNormals{…, Bool}`.
"""
function subsamplednormal(rng::Random.AbstractRNG, n_data::Int; capability::Int=1)
    # The original also computed unused reference values (n_dims, μ_true, L_true).
    # They were dead code and are dropped; they drew nothing from `rng`, so results are unchanged.
    return SubsampledNormals(rng, n_data, true)
end
"""
    main()

Reproducer: compute the value and gradient of `forward_ad` with respect to a random input.
The input goes through `MixedADLogDensityProblem`, using Enzyme reverse mode (runtime
activity enabled, constant function annotation) via DifferentiationInterface.
"""
function main()
    prob = subsamplednormal(Random.default_rng(), 16)
    prob_ad = MixedADLogDensityProblem(prob)
    # The model is one-dimensional: `logdensity(::SubsampledNormals, x)` calls `only(x)`.
    # A length-2 input (as previously used) would throw an ArgumentError once the
    # gradient actually runs.
    y = randn(1)
    adtype = AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    )
    return DifferentiationInterface.value_and_gradient(
        forward_ad, adtype, y, Constant(prob_ad)
    )
end
# Run the reproducer.
main()
Running the code above yields:
ERROR: AssertionError: Enzyme : mismatch between innerTy LLVM.PointerType({} addrspace(10)**) and tape type LLVM.PointerType({} addrspace(10)*)
tape_idx=1
true_idx=2
isKWCall=false
kwtup=nothing
funcTy=Const{typeof(logdensity)}
isghostty(funcTy)=true
miRT=Tuple{Nothing, Nothing}
sret=nothing
returnRoots=nothing
swiftself=false
RT=Active{Float64}
rev_RT=Tuple{Nothing, Nothing}
applicablefn=true
tape=LLVM.CallInst(%23 = call {} addrspace(10)* @jl_get_nth_field_checked({} addrspace(10)* %19, i64 2), !dbg !57)
llvmf={} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)**)
TapeT=Tuple{Nothing, var"#logdensity_pullback#195", Val{false}}
mi=MethodInstance for logdensity(::MixedADLogDensityProblem{SubsampledNormals{Normal{Float64}, Float64, Bool}}, ::Vector{Float64})
ami=MethodInstance for EnzymeCore.EnzymeRules.augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false, false), true, false}, ::Const{typeof(logdensity)}, ::Type{Active{Float64}}, ::Const{MixedADLogDensityProblem{SubsampledNormals{Normal{Float64}, Float64, Bool}}}, ::Duplicated{Vector{Float64}})
rev_TT =Tuple{EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false, false), true, false}, Const{typeof(logdensity)}, Active{Float64}, Tuple{Nothing, var"#logdensity_pullback#195", Val{false}}, Const{MixedADLogDensityProblem{SubsampledNormals{Normal{Float64}, Float64, Bool}}}, Duplicated{Vector{Float64}}}
Stacktrace:
[1] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::LLVM.CallInst)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/rules/customrules.jl:1635
[2] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::Union{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/rules/customrules.jl:2148
[3] enzyme_custom_rev_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, tape::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/rules/llvmrules.jl:48
[4] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, strongZero::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/i91IH/src/api.jl:270
[5] macro expansion
@ ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:2747 [inlined]
[6] macro expansion
@ ~/.julia/packages/LLVM/iza6e/src/base.jl:97 [inlined]
[7] enzyme!(job::GPUCompiler.CompilerJob{…}, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…}, removedRoots::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:2620
[8] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:5812
[9] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/j4HFa/src/driver.jl:67
[10] compile
@ ~/.julia/packages/GPUCompiler/j4HFa/src/driver.jl:55 [inlined]
[11] _thunk(job::GPUCompiler.CompilerJob{…}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6725
[12] _thunk
@ ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6723 [inlined]
[13] cached_compilation
@ ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6781 [inlined]
[14] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6897
[15] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:7041
[16] autodiff
@ ~/.julia/packages/Enzyme/i91IH/src/Enzyme.jl:502 [inlined]
[17] macro expansion
@ ~/.julia/packages/Enzyme/i91IH/src/sugar.jl:287 [inlined]
[18] gradient
@ ~/.julia/packages/Enzyme/i91IH/src/sugar.jl:274 [inlined]
[19] value_and_gradient(f::typeof(forward_ad), prep::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, backend::AutoEnzyme{…}, x::Vector{…}, contexts::Constant{…})
@ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/4n6vR/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:253
[20] value_and_gradient
@ ~/.julia/packages/DifferentiationInterface/4n6vR/src/first_order/gradient.jl:37 [inlined]
[21] main()
@ Main ./REPL[207]:11
[22] top-level scope
@ REPL[208]:1
Some type information was truncated. Use `show(err)` to see complete types.