Turing.jl
Turing.jl copied to clipboard
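`LKJCholesky` returns a `Union`
Sampling from `LKJCholesky` inside a model causes type-stability problems: the sampled value is inferred as a `Union` of two `Cholesky` types (one backed by a `Base.ReshapedArray`, one by a `Matrix`) rather than a single concrete type.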
MWE
using Turing, Random
# Minimal model reproducing the issue: a single `LKJCholesky(2, 3.0)` draw
# (dimension 2, shape parameter 3.0) is enough to trigger the instability.
# NOTE(review): the `@code_warntype` dump shown further down names this variable
# `cov_matrix` (e.g. `VarName{:cov_matrix, ...}`), so it was presumably produced
# with `cov_matrix ~ LKJCholesky(2, 3.0)`; the inferred types should be the same
# with `F` — confirm when re-running.
@model function simple_model()
F ~ LKJCholesky(2, 3.0)
end
# Build the model, then ask inference about its underlying evaluator `model.f`
# when run with a freshly constructed VarInfo under a prior-sampling context.
# The arguments are bound in a `let` so no extra globals are introduced.
model = simple_model()
let varinfo = Turing.VarInfo(model),
    sampling_ctx = Turing.SamplingContext(
        Random.default_rng(), Turing.SampleFromPrior(), Turing.DefaultContext()
    )

    @code_warntype model.f(model, varinfo, sampling_ctx, model.args...)
end
Output (note: this dump was produced with the variable named `cov_matrix` rather than `F`):
MethodInstance for simple_model(::DynamicPPL.Model{typeof(simple_model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, ::DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, ::DynamicPPL.SamplingContext{DynamicPPL.SampleFromPrior, DynamicPPL.DefaultContext, TaskLocalRNG})
from simple_model(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::AbstractPPL.AbstractContext) @ Main Untitled-2:2
Arguments
#self#::Core.Const(simple_model)
__model__::Core.Const(DynamicPPL.Model{typeof(simple_model), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}(simple_model, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext()))
__varinfo__@_3::DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
__context__::Core.Const(DynamicPPL.SamplingContext{DynamicPPL.SampleFromPrior, DynamicPPL.DefaultContext, TaskLocalRNG}(TaskLocalRNG(), DynamicPPL.SampleFromPrior(), DynamicPPL.DefaultContext()))
Locals
@_5::Union{}
@_6::Int64
retval#1531::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
value#1528::Union{}
value#1530::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
cov_matrix::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
isassumption#1527::Bool
vn#1526::AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}
dist#1529::LKJCholesky{Float64}
__varinfo__@_14::DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
@_15::Bool
@_16::Bool
@_17::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
@_18::Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}
Body::Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
1 ── (__varinfo__@_14 = __varinfo__@_3)
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(retval#1531))
│ Core.NewvarNode(:(value#1528))
│ Core.NewvarNode(:(value#1530))
│ Core.NewvarNode(:(cov_matrix))
│ Core.NewvarNode(:(isassumption#1527))
│ (dist#1529 = Main.LKJCholesky(2, 3.0))
│ %10 = Core.apply_type(AbstractPPL.VarName, :cov_matrix)::Core.Const(AbstractPPL.VarName{:cov_matrix})
│ %11 = (%10)()::Core.Const(cov_matrix)
│ (vn#1526 = (DynamicPPL.resolve_varnames)(%11, dist#1529::Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64])))
│ %13 = (DynamicPPL.contextual_isassumption)(__context__, vn#1526)::Core.Const(true)
└─── goto #8 if not %13
2 ── %15 = (DynamicPPL.inargnames)(vn#1526, __model__)::Core.Const(false)
│ %16 = !%15::Core.Const(true)
└─── goto #4 if not %16
3 ── goto #5
4 ── Core.Const(:((DynamicPPL.inmissings)(vn#1526, __model__)))
└─── Core.Const(:(goto %23 if not %19))
5 ┄─ (@_16 = true)
└─── goto #7
6 ── Core.Const(:(@_16 = cov_matrix === Main.missing))
7 ┄─ (@_15 = @_16::Core.Const(true))
└─── goto #9
8 ── Core.Const(:(@_15 = false))
9 ┄─ (isassumption#1527 = @_15::Core.Const(true))
│ %28 = (DynamicPPL.contextual_isfixed)(__context__, vn#1526)::Core.Const(false)
└─── goto #11 if not %28
10 ─ Core.Const(:((DynamicPPL.getfixed_nested)(__context__, vn#1526)))
│ Core.Const(:(cov_matrix = %30))
│ Core.Const(:(@_17 = %30))
└─── Core.Const(:(goto %63))
11 ┄ goto #13 if not isassumption#1527::Core.Const(true)
12 ─ %35 = Core.tuple(__context__)::Core.Const((DynamicPPL.SamplingContext{DynamicPPL.SampleFromPrior, DynamicPPL.DefaultContext, TaskLocalRNG}(TaskLocalRNG(), DynamicPPL.SampleFromPrior(), DynamicPPL.DefaultContext()),))
│ %36 = (DynamicPPL.check_tilde_rhs)(dist#1529::Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64]))::Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64])
│ %37 = (DynamicPPL.unwrap_right_vn)(%36, vn#1526)::Core.PartialStruct(Tuple{LKJCholesky{Float64}, AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Any[Core.PartialStruct(LKJCholesky{Float64}, Any[Core.Const(2), Core.Const(3.0), Char, Float64]), AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}])
│ %38 = Core.tuple(__varinfo__@_14)::Tuple{DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
│ %39 = Core._apply_iterate(Base.iterate, DynamicPPL.tilde_assume!!, %35, %37, %38)::Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
│ %40 = Base.indexed_iterate(%39, 1)::Core.PartialStruct(Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, Int64}, Any[Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, Core.Const(2)])
│ (value#1530 = Core.getfield(%40, 1))
│ (@_6 = Core.getfield(%40, 2))
│ %43 = Base.indexed_iterate(%39, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64}, Any[DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Core.Const(3)])
│ (__varinfo__@_14 = Core.getfield(%43, 1))
│ (cov_matrix = value#1530)
│ (@_18 = value#1530)
└─── goto #14
13 ─ Core.Const(:((DynamicPPL.inargnames)(vn#1526, __model__)))
│ Core.Const(:(!%48))
│ Core.Const(:(goto %52 if not %49))
│ Core.Const(:(cov_matrix = (DynamicPPL.getconditioned_nested)(__context__, vn#1526)))
│ Core.Const(:((DynamicPPL.check_tilde_rhs)(dist#1529)))
│ Core.Const(:(cov_matrix))
│ Core.Const(:(vn#1526))
│ Core.Const(:((DynamicPPL.tilde_observe!!)(__context__, %52, %53, %54, __varinfo__@_14)))
│ Core.Const(:(Base.indexed_iterate(%55, 1)))
│ Core.Const(:(value#1528 = Core.getfield(%56, 1)))
│ Core.Const(:(@_5 = Core.getfield(%56, 2)))
│ Core.Const(:(Base.indexed_iterate(%55, 2, @_5)))
│ Core.Const(:(__varinfo__@_14 = Core.getfield(%59, 1)))
└─── Core.Const(:(@_18 = value#1528))
14 ┄ (@_17 = @_18)
│ (retval#1531 = @_17)
│ %64 = Core.tuple(retval#1531, __varinfo__@_14)::Tuple{Union{Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Cholesky{Float64, Matrix{Float64}}}, DynamicPPL.TypedVarInfo{NamedTuple{(:cov_matrix,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}, Int64}, Vector{LKJCholesky{Float64}}, Vector{AbstractPPL.VarName{:cov_matrix, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
└─── return %64
I just discovered the same thing, and it's absolutely killing the performance of Zygote compared to ReverseDiff in a simple model I just wrote. I'm assuming the problem is somewhere in Bijectors.jl, but I can't see anything obvious.
Ah, yes, sorry — this is indeed a bug. The hotfix is:
function DynamicPPL.reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
    # HACK: when `val` is a view (a `SubArray`, as in the dump above), reshaping it
    # yields a `Base.ReshapedArray`, and the sampled value ends up inferred as a
    # `Union` of a ReshapedArray-backed and a Matrix-backed `Cholesky`. Copying the
    # reshaped data into a dense `Matrix` leaves only one concrete result type.
    dims = size(dist)
    dense = Matrix(reshape(val, dims))
    return DynamicPPL.reconstruct(dist, dense)
end
This should give you type-stability.
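This should be better on a newer release (the `Package@version` identifier here was garbled to "[email protected]" during extraction; presumably a DynamicPPL release — confirm against the original issue).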
> …the performance of Zygote compared to ReverseDiff…
By the way, I don't know whether this is actually the cause of the performance issues with Zygote. Unfortunately, Zygote is very slow on Turing compared to ReverseDiff for many other reasons :confused: