Enzyme.jl icon indicating copy to clipboard operation
Enzyme.jl copied to clipboard

Enzyme doesn't work for `AdvancedVI` Part IX

Open Red-Portal opened this issue 2 weeks ago • 0 comments

Hi, I finally isolated an MWE for the current error. The error seems to be related to importing the rrule: if I don't use MixedADLogDensityProblem, the problem goes away.


using Accessors
using ChainRulesCore
using Enzyme
using Statistics
using DifferentiationInterface
using ADTypes

"""
    MixedADLogDensityProblem(problem)

Wrapper that forwards `logdensity` calls to the wrapped `problem`, while a
custom `ChainRulesCore.rrule` (defined below) routes reverse-mode gradients
through `logdensity_and_gradient` of the inner problem instead of
differentiating through `logdensity` itself.
"""
struct MixedADLogDensityProblem{Problem}
    problem::Problem
end

# Forward the dimension query to the wrapped problem.
dimension(wrapper::MixedADLogDensityProblem) = dimension(wrapper.problem)

# Evaluate the log-density of the wrapped problem at `x`.
logdensity(wrapper::MixedADLogDensityProblem, x::AbstractArray) =
    logdensity(wrapper.problem, x)

"""
    ChainRulesCore.rrule(::typeof(logdensity), wrapper::MixedADLogDensityProblem, x)

Custom reverse rule for the wrapped `logdensity`: both the primal value and
the gradient come from `logdensity_and_gradient` on the inner problem, so
reverse-mode AD never differentiates through `logdensity` itself. The
pullback scales the pre-computed gradient by the incoming cotangent.
"""
function ChainRulesCore.rrule(
    ::typeof(logdensity),
    wrapper::MixedADLogDensityProblem,
    x::AbstractArray,
)
    value, grad = logdensity_and_gradient(wrapper.problem, x)
    function logdensity_pullback(Δy)
        # Delay forming the cotangent for `x` until it is actually needed.
        x̄ = ChainRulesCore.@thunk(Δy' * grad)
        return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
    end
    return value, logdensity_pullback
end

# Register the ChainRulesCore rrule above as an Enzyme custom rule for
# `logdensity(::MixedADLogDensityProblem, ::AbstractArray)`. Per the report,
# the failure disappears when the MixedADLogDensityProblem path is not used.
Enzyme.@import_rrule(typeof(logdensity), MixedADLogDensityProblem, AbstractArray)

# Entry point handed to DifferentiationInterface; differentiated w.r.t. `x`,
# with `prob` passed as a constant context.
forward_ad(x, prob) = logdensity(prob, x)

using Distributions
using Random
using LinearAlgebra
using ForwardDiff

# Toy model: a collection of Normal likelihood terms evaluated at a shared
# scalar parameter, re-weighted by a subsampling adjustment factor.
struct SubsampledNormals{D<:Normal,F<:Real,C}
    dists::Vector{D}  # per-datum Normal likelihood components
    likeadj::F        # likelihood re-weighting factor (1.0 in this MWE)
    capability::C     # opaque tag carried in the type; not read by logdensity
end

"""
    SubsampledNormals(rng, n_normals, capability)

Build a `SubsampledNormals` model with `n_normals` unit-variance normals whose
means are drawn from `rng`, a likelihood adjustment of `1.0`, and the given
`capability` tag.
"""
function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int, capability)
    means = randn(rng, n_normals)
    components = Normal.(means, ones(n_normals))
    # The implicit constructor infers the same type parameters the original
    # spelled out explicitly: {eltype(components), Float64, typeof(capability)}.
    return SubsampledNormals(components, 1.0, capability)
end

# Log-density of the model at the parameter vector `x`: the sum of component
# log-pdfs at `only(x)`, scaled by the subsampling adjustment.
function logdensity(m::SubsampledNormals, x)
    θ = only(x)  # model is one-dimensional; `only` throws if length(x) != 1
    return m.likeadj * sum(d -> logpdf(d, θ), m.dists)
end

"""
    logdensity_and_gradient(m::SubsampledNormals, x)

Return the pair `(logdensity(m, x), ∇ₓ logdensity(m, x))`, with the gradient
computed via ForwardDiff.
"""
function logdensity_and_gradient(m::SubsampledNormals, x)
    value = logdensity(m, x)
    grad = ForwardDiff.gradient(x_ -> logdensity(m, x_), x)
    return value, grad
end

"""
    subsamplednormal(rng, n_data; capability=1)

Construct a `SubsampledNormals` test model with `n_data` components whose
means are drawn from `rng`.

NOTE(review): the `capability` keyword is currently ignored — the model is
always built with a hard-coded `Bool` tag (`true`). Confirm whether the
keyword should be threaded through; honoring it would change the model's
third type parameter from `Bool` to `Int` (which appears in the reported
error's type signatures).
"""
function subsamplednormal(rng::Random.AbstractRNG, n_data::Int; capability::Int=1)
    # Removed unused locals from the original (`n_dims`, `μ_true`, `L_true`);
    # they drew nothing from `rng`, so the random stream is unchanged.
    cap = true  # hard-coded; see NOTE above
    return SubsampledNormals(rng, n_data, cap)
end

"""
    main()

Reproduce the Enzyme `@import_rrule` tape-type mismatch: wrap the test
problem in `MixedADLogDensityProblem` and request a value-and-gradient via
DifferentiationInterface with Enzyme reverse mode.
"""
function main()
    prob = subsamplednormal(Random.default_rng(), 16)
    prob_ad = MixedADLogDensityProblem(prob)

    # NOTE(review): the model is 1-dimensional (`logdensity` calls `only(x)`),
    # so a length-2 input would throw at runtime; kept as in the report since
    # the Enzyme error occurs during compilation, before `only` is reached.
    y = randn(2)
    # Removed the unused `dy = zero(y)` from the original.
    adtype = AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    )
    return DifferentiationInterface.value_and_gradient(
        forward_ad, adtype, y, Constant(prob_ad)
    )
end

# Running the script triggers the Enzyme tape-type mismatch error shown below.
main()

Running the code above yields:

ERROR: AssertionError: Enzyme : mismatch between innerTy LLVM.PointerType({} addrspace(10)**) and tape type LLVM.PointerType({} addrspace(10)*)
tape_idx=1
true_idx=2
isKWCall=false
kwtup=nothing
funcTy=Const{typeof(logdensity)}
isghostty(funcTy)=true
miRT=Tuple{Nothing, Nothing}
sret=nothing
returnRoots=nothing
swiftself=false
RT=Active{Float64}
rev_RT=Tuple{Nothing, Nothing}
applicablefn=true
tape=LLVM.CallInst(%23 = call {} addrspace(10)* @jl_get_nth_field_checked({} addrspace(10)* %19, i64 2), !dbg !57)
llvmf={} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32, {} addrspace(10)**)
TapeT=Tuple{Nothing, var"#logdensity_pullback#195", Val{false}}
mi=MethodInstance for logdensity(::MixedADLogDensityProblem{SubsampledNormals{Normal{Float64}, Float64, Bool}}, ::Vector{Float64})
ami=MethodInstance for EnzymeCore.EnzymeRules.augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false, false), true, false}, ::Const{typeof(logdensity)}, ::Type{Active{Float64}}, ::Const{MixedADLogDensityProblem{SubsampledNormals{Normal{Float64}, Float64, Bool}}}, ::Duplicated{Vector{Float64}})
rev_TT =Tuple{EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false, false), true, false}, Const{typeof(logdensity)}, Active{Float64}, Tuple{Nothing, var"#logdensity_pullback#195", Val{false}}, Const{MixedADLogDensityProblem{SubsampledNormals{Normal{Float64}, Float64, Bool}}}, Duplicated{Vector{Float64}}}

Stacktrace:
  [1] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::LLVM.CallInst)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/rules/customrules.jl:1635
  [2] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::Union{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/rules/customrules.jl:2148
  [3] enzyme_custom_rev_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, tape::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/rules/llvmrules.jl:48
  [4] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, strongZero::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/i91IH/src/api.jl:270
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:2747 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/LLVM/iza6e/src/base.jl:97 [inlined]
  [7] enzyme!(job::GPUCompiler.CompilerJob{…}, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…}, removedRoots::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:2620
  [8] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:5812
  [9] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/j4HFa/src/driver.jl:67
 [10] compile
    @ ~/.julia/packages/GPUCompiler/j4HFa/src/driver.jl:55 [inlined]
 [11] _thunk(job::GPUCompiler.CompilerJob{…}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6725
 [12] _thunk
    @ ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6723 [inlined]
 [13] cached_compilation
    @ ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6781 [inlined]
 [14] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:6897
 [15] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/i91IH/src/compiler.jl:7041
 [16] autodiff
    @ ~/.julia/packages/Enzyme/i91IH/src/Enzyme.jl:502 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/Enzyme/i91IH/src/sugar.jl:287 [inlined]
 [18] gradient
    @ ~/.julia/packages/Enzyme/i91IH/src/sugar.jl:274 [inlined]
 [19] value_and_gradient(f::typeof(forward_ad), prep::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, backend::AutoEnzyme{…}, x::Vector{…}, contexts::Constant{…})
    @ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/4n6vR/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl:253
 [20] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/4n6vR/src/first_order/gradient.jl:37 [inlined]
 [21] main()
    @ Main ./REPL[207]:11
 [22] top-level scope
    @ REPL[208]:1
Some type information was truncated. Use `show(err)` to see complete types.

Red-Portal avatar Dec 12 '25 00:12 Red-Portal