Turing.jl

LKJCholesky returns a Union

Status: Open · tiemvanderdeure opened this issue 1 year ago · 4 comments

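Sampling from `LKJCholesky` inside a model is type-unstable: the sampled value is inferred as a `Union` of two `Cholesky` types, one backed by a `Base.ReshapedArray` and one backed by a `Matrix{Float64}`.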

MWE

using Turing, Random

@model function simple_model()
    # Minimal model with a single `LKJCholesky`-distributed variable. It is named
    # `cov_matrix` so that the variable names in the `@code_warntype` output
    # below (`cov_matrix`, `VarName{:cov_matrix, ...}`) match this model.
    cov_matrix ~ LKJCholesky(2, 3.0)
end

model = simple_model()

# Show inferred types for a single evaluation of the model body. The
# evaluation samples from the prior into a freshly built VarInfo. The
# arguments are bound in a `let` so the script does not leave extra globals.
let vi = Turing.VarInfo(model),
    ctx = Turing.SamplingContext(
        Random.default_rng(), Turing.SampleFromPrior(), Turing.DefaultContext()
    )

    @code_warntype model.f(model, vi, ctx, model.args...)
end

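Running this prints the following `@code_warntype` output. Note the `Union` of two `Cholesky` types inferred for the sampled value and for the returned tuple: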

MethodInstance for simple_model(::DynamicPPL.Model{typeof(simple_model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, ::DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, ::DynamicPPL.SamplingContext{DynamicPPL.SampleFromPrior, DynamicPPL.DefaultContext, TaskLocalRNG})
  from simple_model(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::AbstractPPL.AbstractContext) @ Main Untitled-2:2
Arguments
  #self#::Core.Const(simple_model)
  __model__::Core.Const(DynamicPPL.Model{typeof(simple_model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}(simple_model, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext()))
  __varinfo__@_3::DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
  __context__::Core.Const(DynamicPPL.SamplingContext{DynamicPPL.SampleFromPrior, DynamicPPL.DefaultContext, TaskLocalRNG}(TaskLocalRNG(), DynamicPPL.SampleFromPrior(), DynamicPPL.DefaultContext()))
Locals
  @_5::Union{}
  @_6::Int64
  retval#1531::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
  value#1528::Union{}
  value#1530::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
  cov_matrix::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
  isassumption#1527::Bool
  vn#1526::AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}
  dist#1529::LKJCholesky{Float64}
  __varinfo__@_14::DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
  @_15::Bool
  @_16::Bool
  @_17::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
  @_18::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
Body::Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
1 ──       (__varinfo__@_14 = __varinfo__@_3)
│          Core.NewvarNode(:(@_5))
│          Core.NewvarNode(:(@_6))
│          Core.NewvarNode(:(retval#1531))
│          Core.NewvarNode(:(value#1528))
│          Core.NewvarNode(:(value#1530))
│          Core.NewvarNode(:(cov_matrix))
│          Core.NewvarNode(:(isassumption#1527))
│          (dist#1529 = Main.LKJCholesky(2, 3.0))
│    %10 = Core.apply_type(AbstractPPL.VarName, :cov_matrix)::Core.Const(AbstractPPL.VarName{:cov_matrix})
│    %11 = (%10)()::Core.Const(cov_matrix)
│          (vn#1526 = (DynamicPPL.resolve_varnames)(%11, dist#1529::Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64])))
│    %13 = (DynamicPPL.contextual_isassumption)(__context__, vn#1526)::Core.Const(true)
└───       goto #8 if not %13
2 ── %15 = (DynamicPPL.inargnames)(vn#1526, __model__)::Core.Const(false)
│    %16 = !%15::Core.Const(true)
└───       goto #4 if not %16
3 ──       goto #5
4 ──       Core.Const(:((DynamicPPL.inmissings)(vn#1526, __model__)))
└───       Core.Const(:(goto %23 if not %19))
5 ┄─       (@_16 = true)
└───       goto #7
6 ──       Core.Const(:(@_16 = cov_matrix === Main.missing))
7 ┄─       (@_15 = @_16::Core.Const(true))
└───       goto #9
8 ──       Core.Const(:(@_15 = false))
9 ┄─       (isassumption#1527 = @_15::Core.Const(true))
│    %28 = (DynamicPPL.contextual_isfixed)(__context__, vn#1526)::Core.Const(false)
└───       goto #11 if not %28
10 ─       Core.Const(:((DynamicPPL.getfixed_nested)(__context__, vn#1526)))
│          Core.Const(:(cov_matrix = %30))
│          Core.Const(:(@_17 = %30))
└───       Core.Const(:(goto %63))
11 ┄       goto #13 if not isassumption#1527::Core.Const(true)
12 ─ %35 = Core.tuple(__context__)::Core.Const((DynamicPPL.SamplingContext{DynamicPPL.SampleFromPrior, DynamicPPL.DefaultContext, TaskLocalRNG}(TaskLocalRNG(), DynamicPPL.SampleFromPrior(), DynamicPPL.DefaultContext()),))
│    %36 = (DynamicPPL.check_tilde_rhs)(dist#1529::Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64]))::Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64])
│    %37 = (DynamicPPL.unwrap_right_vn)(%36, vn#1526)::Core.PartialStruct(Tuple{LKJCholesky{Float64}, AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Any[Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64]), AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}])
│    %38 = Core.tuple(__varinfo__@_14)::Tuple{DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
│    %39 = Core._apply_iterate(Base.iterate, DynamicPPL.tilde_assume!!, %35, %37, %38)::Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
│    %40 = Base.indexed_iterate(%39, 1)::Core.PartialStruct(Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, Int64}, Any[Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, Core.Const(2)])
│          (value#1530 = Core.getfield(%40, 1))
│          (@_6 = Core.getfield(%40, 2))
│    %43 = Base.indexed_iterate(%39, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64}, Any[DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Core.Const(3)])
│          (__varinfo__@_14 = Core.getfield(%43, 1))
│          (cov_matrix = value#1530)
│          (@_18 = value#1530)
└───       goto #14
13 ─       Core.Const(:((DynamicPPL.inargnames)(vn#1526, __model__)))
│          Core.Const(:(!%48))
│          Core.Const(:(goto %52 if not %49))
│          Core.Const(:(cov_matrix = (DynamicPPL.getconditioned_nested)(__context__, vn#1526)))
│          Core.Const(:((DynamicPPL.check_tilde_rhs)(dist#1529)))
│          Core.Const(:(cov_matrix))
│          Core.Const(:(vn#1526))
│          Core.Const(:((DynamicPPL.tilde_observe!!)(__context__, %52, %53, %54, __varinfo__@_14)))
│          Core.Const(:(Base.indexed_iterate(%55, 1)))
│          Core.Const(:(value#1528 = Core.getfield(%56, 1)))
│          Core.Const(:(@_5 = Core.getfield(%56, 2)))
│          Core.Const(:(Base.indexed_iterate(%55, 2, @_5)))
│          Core.Const(:(__varinfo__@_14 = Core.getfield(%59, 1)))
└───       Core.Const(:(@_18 = value#1528))
14 ┄       (@_17 = @_18)
│          (retval#1531 = @_17)
│    %64 = Core.tuple(retval#1531, __varinfo__@_14)::Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
└───       return %64

— tiemvanderdeure, Oct 17 '23 14:10

Just discovered the same thing, and it's absolutely killing performance of Zygote compared to ReverseDiff in a simple model I just wrote. I'm assuming that the problem is somehow in Bijectors.jl, but I can't see anything obvious.

— farr, Oct 20 '23 18:10

Ah, yes, sorry: this is indeed a bug. Until it is fixed upstream, the hotfix is to define:

# Temporary workaround for the type instability shown above. The flat value
# vector is reshaped to the distribution's size. Reshaping a view of the
# stored vector presumably yields a `Base.ReshapedArray` (as in the inferred
# `Union` above), so the rebuilt `Cholesky` can carry a different matrix type
# than the `Matrix`-backed one. Copying into a dense `Matrix` pins a single
# concrete type.
# NOTE: this is type piracy (neither `DynamicPPL.reconstruct` nor
# `LKJCholesky` is owned here). Remove it once the upstream fix is available.
function DynamicPPL.reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
    # HACK: Add a `Matrix` to concretize the `reshape`.
    shape = size(dist)
    dense = Matrix(reshape(val, shape))
    return DynamicPPL.reconstruct(dist, dense)
end

This should give you type-stability.

— torfjelde, Oct 23 '23 16:10

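This should be better in a newer release. The exact `package@version` reference was garbled into "[email protected]" by email obfuscation during extraction.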

— torfjelde, Oct 25 '23 12:10

> …absolutely killing performance of Zygote compared to ReverseDiff…

By the way, I'm not sure this is the cause of the Zygote performance issues. Unfortunately, Zygote is very slow with Turing compared to ReverseDiff for many other reasons. :confused:

— torfjelde, Oct 25 '23 12:10