DynamicPPL.jl
Need alternative to `NamedTuple` for `SimpleVarInfo`
Problem
Now that we properly support the underlying storage of a varinfo changing size after linking, the current use of `NamedTuple`, both for the "ground truth" in `TestUtils`, e.g.
https://github.com/TuringLang/DynamicPPL.jl/blob/549d9b150078eaffaf91f324d439c133e4314303/src/test_utils.jl#L33-L37
and in `SimpleVarInfo`, makes less sense than it did before.
To see why, let's consider the following example:
```julia
julia> using DynamicPPL, Distributions

julia> @model function demo()
           x = Vector{Float64}(undef, 5)
           x[1] ~ Normal()
           x[2:3] ~ Dirichlet([1.0, 2.0])
           return (x=x,)
       end
demo (generic function with 2 methods)

julia> model = demo();

julia> nt = model()
(x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],)

julia> # Construct `SimpleVarInfo` from `nt`.
       vi = SimpleVarInfo(nt)
SimpleVarInfo((x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],), 0.0)

julia> vn = @varname(x[2:3])
x[2:3]

julia> # (✓) Everything works nicely
       vi[vn]
2-element Vector{Float64}:
0.6662187241949805
0.3337812758050194

julia> # Now we link it!
       vi_linked = DynamicPPL.link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 1 elements to 2 destinations
Stacktrace:
  [1] throw_setindex_mismatch(X::Vector{Float64}, I::Tuple{Int64})
    @ Base ./indices.jl:191
  [2] setindex_shape_check
    @ ./indices.jl:245 [inlined]
  [3] setindex!
    @ ./array.jl:994 [inlined]
  [4] _setindex!
    @ ~/.julia/packages/BangBang/FUkah/src/base.jl:480 [inlined]
  [5] may
    @ ~/.julia/packages/BangBang/FUkah/src/core.jl:9 [inlined]
  [6] setindex!!
    @ ~/.julia/packages/BangBang/FUkah/src/base.jl:478 [inlined]
  [7] set(obj::Vector{Float64}, lens::BangBang.SetfieldImpl.Lens!!{Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, value::Vector{Float64})
    @ BangBang.SetfieldImpl ~/.julia/packages/BangBang/FUkah/src/setfield.jl:34
  [8] set
    @ ~/.julia/packages/Setfield/PdKfV/src/lens.jl:188 [inlined]
  [9] set
    @ ~/.julia/packages/BangBang/FUkah/src/setfield.jl:17 [inlined]
 [10] set!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/utils.jl:354 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Setfield/PdKfV/src/sugar.jl:197 [inlined]
 [12] setindex!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/simple_varinfo.jl:339 [inlined]
 [13] tilde_assume(#unused#::DynamicPPL.DynamicTransformationContext{false}, right::Dirichlet{Float64, Vector{Float64}, Float64}, vn::VarName{:x, Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation})
    @ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:19
 [14] tilde_assume!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/context_implementations.jl:117 [inlined]
 [15] demo(__model__::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext}, __varinfo__::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, __context__::DynamicPPL.DynamicTransformationContext{false})
    @ Main ./REPL[47]:4
 [16] _evaluate!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:963 [inlined]
 [17] evaluate_threadunsafe!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:936 [inlined]
 [18] evaluate!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:889 [inlined]
 [19] link!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:86 [inlined]
 [20] link!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:384 [inlined]
 [21] link!!(vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, model::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext})
    @ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:378
 [22] top-level scope
    @ REPL[53]:2
```
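The issue here really just boils down to the fact that we're trying to use the varname
```julia
julia> vn
x[2:3]
```
to index into a `NamedTuple` whose `x` entry is a single `Vector{Float64}` of length 5. After the transformation, the value for `x[2:3]` is represented by a 1-element vector rather than a 2-element vector, so it no longer fits into the 2-element slot it is written to (hence the `DimensionMismatch` above).
In contrast, `SimpleVarInfo{<:AbstractDict}` works just fine, because there each varname gets its own entry: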
```julia
julia> # Construct `SimpleVarInfo` using a dict now.
       vi = SimpleVarInfo(rand(OrderedDict, model))
SimpleVarInfo(OrderedDict{Any, Any}(x[1] => -0.12337922752695839, x[2:3] => [0.7836009759179734, 0.21639902408202646]), 0)

julia> # (✓) Everything works nicely
       vi[vn]
2-element Vector{Float64}:
0.7836009759179734
0.21639902408202646

julia> # Now we link it!
       vi_linked = DynamicPPL.link!!(vi, model);

julia> # (✓) Everything works nicely
       vi_linked[vn]
1-element Vector{Float64}:
1.2867758943235161
```
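The same failure mode can be reproduced without DynamicPPL. The sketch below is a minimal Base-Julia illustration under simplifying assumptions: a plain `Vector{Float64}` stands in for the `x` entry of the `NamedTuple`, `String` keys stand in for `VarName`s in the dict-based storage, and `[1.28]` is an illustrative stand-in for the linked value of `x[2:3]` (a 2-dimensional simplex maps to a single unconstrained number, as the 1-element output above shows).

```julia
# Stand-in for the `x` entry of the `NamedTuple`: one flat vector for all of `x`.
x = zeros(5)
x[2:3] = [0.66, 0.34]   # writing the constrained Dirichlet sample works

# Illustrative linked value: a 2-dimensional simplex maps to 1 unconstrained number.
linked = [1.28]

# Writing it back into the 2-element slot for `x[2:3]` fails.
failed = try
    x[2:3] = linked
    false
catch err
    err isa DimensionMismatch
end

# Stand-in for dict-based storage: each varname owns its entry,
# so replacing a 2-element value by a 1-element one is fine.
d = Dict{String,Vector{Float64}}("x[1]" => [-0.08], "x[2:3]" => [0.66, 0.34])
d["x[2:3]"] = linked

println(failed)               # true
println(length(d["x[2:3]"]))  # 1
```

The flat vector fails on the size check before anything is written, so `x` keeps its old values; the dict simply rebinds the entry.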
"Luckily" it has always been the plan that SimpleVarInfo
should be able to use different underlying representations fairly easily, e.g. I've successfully used it with `ComponentVector` from ComponentArrays.jl many times before. And so we should probably find a more flexible default representation for `SimpleVarInfo` that can be used in more cases.
Solution
Option 1: Use `OrderedDict` by default
This one is obviously not great for performance reasons, but it will "just work" in all cases and it's very simple to reason about.
Option 2: Dict-like flattened representation
In an ideal world, the underlying representation of the values in a varinfo would have the following properties:
1. It's type-stable, when possible.
2. It's contiguous in memory, when possible.
3. It's indexable by `VarName`.

Something like an `OrderedDict` fails in two regards:
- It's not contiguous in memory.
- Type-stability is not guaranteed, unless we create a dictionary for each eltype or something similar.
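To make the type-stability point concrete: the dict produced by `rand(OrderedDict, model)` above is an `OrderedDict{Any, Any}`, so lookups cannot be inferred. The sketch below uses Base `Dict` as a stand-in for `OrderedDict` and `Symbol` keys as a stand-in for `VarName`; `Base.return_types` is Base's inference-inspection utility, used here only to show the inferred result type.

```julia
# Lookup helper; what matters is its inferred return type.
lookup(d) = d[:x]

# Mixed scalar/vector values force an untyped dict ...
untyped = Dict{Any,Any}(:x => 1.0, :y => [2.0, 3.0])
# ... whereas a single concrete value type keeps lookups inferable.
typed = Dict{Symbol,Vector{Float64}}(:x => [1.0], :y => [2.0, 3.0])

println(only(Base.return_types(lookup, (typeof(untyped),))))  # Any
println(only(Base.return_types(lookup, (typeof(typed),))))    # Vector{Float64}

# Even the typed dict is not contiguous: each value is its own heap-allocated array.
println(typeof(collect(values(typed))))  # Vector{Vector{Float64}}
```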
Current `Metadata` used by `VarInfo`
The `Metadata` type in `VarInfo` is a good example of something that satisfies all three properties (of course, the "when possible" in Property (1) is not concrete, but `VarInfo` uses a `NamedTuple` of `Metadata` to achieve this in most common use-cases).
As a reminder, here is what `Metadata` looks like:
https://github.com/TuringLang/DynamicPPL.jl/blob/e05bb0935a1e1a06027c603cc04c20f23195a6c4/src/varinfo.jl#L39-L72
Most important for dict-like storage of values are the following lines:
https://github.com/TuringLang/DynamicPPL.jl/blob/e05bb0935a1e1a06027c603cc04c20f23195a6c4/src/varinfo.jl#L46-L58
With this, it's fairly easy to implement nice indexing behavior for `VarInfo`. Here's a simple sketch of what a `getindex` could look like for `Metadata`:
```julia
function Base.getindex(metadata::Metadata, varname::VarName)
    # Get the index for this `varname`.
    idx = metadata.idcs[varname]
    # Get the range for this `varname`.
    r = metadata.ranges[idx]
    # Extract the (flattened) value; the field holding all values is `vals`.
    return metadata.vals[r]
end
```
This is effectively what `getval` currently implements:
https://github.com/TuringLang/DynamicPPL.jl/blob/e05bb0935a1e1a06027c603cc04c20f23195a6c4/src/varinfo.jl#L318C44-L318C44
This then results in a `Vector` of the flattened representation of `vn`.
Our current implementation of `Base.getindex` for `VarInfo` then contains more complexity to convert the `Vector` back into the original form expected by the corresponding distribution, and its usage looks like `varinfo[varname, dist]`.
Since the `dist` is also stored in the `Metadata`, the above in fact works the same if you do `varinfo[varname]`, provided `varinfo isa VarInfo` and not a `SimpleVarInfo`. But, as has been discussed many times before, this is not great because it doesn't properly handle dynamic constraints, etc.; we want to use the `dist` at the place of indexing, not the one from the construction of the `varinfo`.
Nonetheless, the value-storage part of `Metadata` arguably provides quite a nice way to store values in a dict-like way while satisfying the three properties above.
So why not just use `Metadata`?
Well, we probably should. But if we do, we should probably simplify its structure quite a bit.
For example, should we drop the following fields?
- `dists`: As mentioned, this is often not the correct thing to use.
- `gids`: This is used by the Gibbs sampler, and will at some point no longer be needed, since we now have ways of programmatically conditioning and deconditioning models.
- `orders`: Only used by particle methods to keep track of the number of observe statements hit. This should probably either be moved somewhere else or at least not be hardcoded into the "main" dict-like object.
- `flags`: This might be generally useful, but the flags currently used (`istrans` and `delete`) are no longer that useful: `istrans` should be replaced by explicit transformations, as is done in `SimpleVarInfo`, and `delete` should also no longer be needed, as we now have a clear way of indicating whether we're running a model in "sampling mode" or not, using `SamplingContext`.
But the problem with doing this is that we'll break a lot of code that currently depends on `VarInfo` functioning as-is.
This is also the main reason why we introduced `SimpleVarInfo`: to allow us to create simpler and different representations of varinfos without breaking existing code.
So what should we do?
For now, it might be a good idea to just introduce a type very similar to `Metadata` but simpler in its form, i.e. mainly just a value container.
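A minimal sketch of what such a value container could look like: it keeps only the value-storage fields of `Metadata` (`idcs`, `vns`, `ranges`, `vals`) and drops `dists`, `gids`, `orders` and `flags`. The name `ValueMetadata`, its `push!` method and the use of `Symbol` keys in place of `VarName` are hypothetical, chosen so the sketch runs with Base Julia alone.

```julia
# Hypothetical stripped-down `Metadata`: only what is needed to store values.
struct ValueMetadata{K,T}
    idcs::Dict{K,Int}               # key => position in `vns`/`ranges`
    vns::Vector{K}                  # keys in insertion order
    ranges::Vector{UnitRange{Int}}  # where each key's values live in `vals`
    vals::Vector{T}                 # all values, flattened and contiguous
end

ValueMetadata{K,T}() where {K,T} = ValueMetadata(Dict{K,Int}(), K[], UnitRange{Int}[], T[])

# Append a new key together with its (already flattened) values.
function Base.push!(md::ValueMetadata, kv::Pair)
    k, v = kv
    haskey(md.idcs, k) && throw(ArgumentError("key $k already present"))
    n = length(md.vals)
    push!(md.vns, k)
    push!(md.ranges, (n + 1):(n + length(v)))
    append!(md.vals, v)
    md.idcs[k] = length(md.vns)
    return md
end

# Same logic as the `getindex` sketch for `Metadata`: key -> index -> range -> values.
function Base.getindex(md::ValueMetadata, k)
    r = md.ranges[md.idcs[k]]
    return md.vals[r]
end

md = ValueMetadata{Symbol,Float64}()
push!(md, :a => [1.0, 2.0, 3.0])
push!(md, :b => [4.0])
md[:a]   # [1.0, 2.0, 3.0]
md.vals  # [1.0, 2.0, 3.0, 4.0]
```

Unlike `Metadata`, this deliberately knows nothing about distributions, so turning the flat values back into their original form would need some other stored information.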
We could either implement our own, or we could see if there are existing implementations in the ecosystem that could benefit us, e.g. Dictionaries.jl seems like it might be suitable.
Dictionaries.jl will unfortunately not give us contiguous memory (at least not by default):
```julia
julia> using Dictionaries

julia> vnv = Dictionary{VarName}(OrderedDict(@varname(a) => [1.0,2,3], @varname(b) => [4,5,6.]))
2-element Dictionary{VarName, Vector{Float64}}
a │ [1.0, 2.0, 3.0]
b │ [4.0, 5.0, 6.0]

julia> vnv.values
2-element Vector{Vector{Float64}}:
[1.0, 2.0, 3.0]
[4.0, 5.0, 6.0]
```
How much does contiguous memory actually matter? I would imagine CPU cache fills in a lot of the performance gap.
Well, it depends. Generally speaking, I expect the access patterns here to be mainly views, so if you then reshape these into larger matrices which are then used for downstream computations, we'd expect it to make a noticeable difference, no?
But no matter what we do, we should benchmark these things at some point.
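A small Base-Julia sketch of the point about views: with contiguous storage, a block of variables can be viewed and reshaped into a matrix without copying, whereas a `Vector{Vector{Float64}}` (as in the Dictionaries.jl example above) has to be copied into a new matrix first. The layout (two 3-element variables stored back-to-back) is an assumption for illustration, and this says nothing about actual timings, which still need benchmarking.

```julia
# Contiguous storage: two 3-element variables stored back-to-back.
vals = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ranges = [1:3, 4:6]

# Zero-copy: a 3×2 matrix (one column per variable) that aliases `vals`.
M = reshape(view(vals, first(ranges[1]):last(ranges[2])), 3, 2)

# Non-contiguous storage: building the same matrix requires a copy.
vecs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
M_copy = reduce(hcat, vecs)

# Writes through `vals` are visible in `M`, but not in `M_copy`.
vals[1] = 10.0
println(M[1, 1])       # 10.0
println(M_copy[1, 1])  # 1.0
```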
And rolling our own shouldn't be too difficult.
For example, we could use something like:
```julia
# Needed when this is evaluated outside DynamicPPL's own source tree.
using DynamicPPL: DynamicPPL, VarName
using OrderedCollections: OrderedDict
using BangBang: BangBang

"""
    VarNameDict

A `VarNameDict` is a vector-like collection of values that can be indexed by `VarName`.

This is basically like an `OrderedDict{<:VarName}` but ensures that the underlying values
are stored contiguously in memory.
"""
struct VarNameDict{K,T,V<:AbstractVector{T},D<:AbstractDict{K}} <: AbstractDict{K,T}
    values::V
    varname_to_ranges::D
end

function VarNameDict(dict::OrderedDict)
    # Assign each key a consecutive range in the flattened vector.
    offset = 0
    ranges = map(values(dict)) do x
        r = (offset + 1):(offset + length(x))
        offset = r[end]
        r
    end
    vals = mapreduce(DynamicPPL.vectorize, vcat, values(dict))
    return VarNameDict(vals, OrderedDict(zip(keys(dict), ranges)))
end

# Dict-like functionality.
Base.keys(vnd::VarNameDict) = keys(vnd.varname_to_ranges)
Base.values(vnd::VarNameDict) = vnd.values
# NOTE: this is the number of stored scalar values, not the number of keys.
Base.length(vnd::VarNameDict) = length(vnd.values)
Base.getindex(vnd::VarNameDict, i) = getindex(vnd.values, i)
Base.setindex!(vnd::VarNameDict, val, i) = setindex!(vnd.values, val, i)

# Range that `x` would occupy if appended to the end of the underlying vector.
function nextrange(vnd::VarNameDict, x)
    n = length(vnd.values)
    return (n + 1):(n + length(x))
end

function Base.getindex(vnd::VarNameDict, vn::VarName)
    return getindex(vnd.values, vnd.varname_to_ranges[vn])
end

function Base.setindex!(vnd::VarNameDict, val, vn::VarName)
    # Flatten so that scalars and arrays can both be written into a range.
    val_vec = DynamicPPL.vectorize(val)
    # If we don't have `vn` in the dictionary, then we need to add it.
    if !haskey(vnd.varname_to_ranges, vn)
        # Set the range for the new variable.
        r = nextrange(vnd, val_vec)
        vnd.varname_to_ranges[vn] = r
        # Resize the underlying vector to accommodate the new values.
        resize!(vnd.values, r[end])
    else
        # Existing keys need to be handled differently depending on
        # whether the size of the value is increasing or decreasing.
        r = vnd.varname_to_ranges[vn]
        n_val = length(val_vec)
        n_r = length(r)
        if n_val > n_r
            # Remove the old range.
            delete!(vnd.varname_to_ranges, vn)
            # Add a new range at the end of the underlying vector.
            r_new = nextrange(vnd, val_vec)
            vnd.varname_to_ranges[vn] = r_new
            # Resize the underlying vector to accommodate the new values.
            resize!(vnd.values, r_new[end])
        else  # n_val <= n_r
            # Just shrink (or keep) the current range.
            vnd.varname_to_ranges[vn] = r[1]:(r[1] + n_val - 1)
        end
        # TODO: Keep track of unused ranges so we can perform sweeps
        # every now and then to free up memory and re-contiguize the
        # underlying vector.
    end
    return setindex!(vnd.values, val_vec, vnd.varname_to_ranges[vn])
end

function BangBang.setindex!!(vnd::VarNameDict, val, vn::VarName)
    setindex!(vnd, val, vn)
    return vnd
end

function Base.iterate(vnd::VarNameDict, state=nothing)
    res = state === nothing ? iterate(vnd.varname_to_ranges) : iterate(vnd.varname_to_ranges, state)
    res === nothing && return nothing
    (vn, range), state_new = res
    return vn => vnd.values[range], state_new
end
```
Adding this to DPPL and exporting:
```julia
julia> using DynamicPPL

julia> model = DynamicPPL.TestUtils.demo_one_variable_multiple_constraints()
Model{typeof(DynamicPPL.TestUtils.demo_one_variable_multiple_constraints), (Symbol("##arg#375"),), (), (), Tuple{DataType}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_one_variable_multiple_constraints, (var"##arg#375" = Vector{Float64},), NamedTuple(), DefaultContext())

julia> x = rand(OrderedDict, model)
OrderedDict{Any, Any} with 4 entries:
x[1] => 1.18456
x[2] => 1.83516
x[3] => 0.726668
x[4:5] => [0.191595, 0.808405]

julia> vnd = VarNameDict(x)
VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}} with 5 entries:
x[1] => [1.18456]
x[2] => [1.83516]
x[3] => [0.726668]
x[4:5] => [0.191595, 0.808405]

julia> vnd[@varname(x[1])]
1-element Vector{Float64}:
1.1845589710704487

julia> vnd[@varname(x[4:5])]
2-element Vector{Float64}:
0.19159541239321576
0.8084045876067844

julia> # Can create a `SimpleVarInfo` from a `VarNameDict`.
       vi = SimpleVarInfo(vnd)
SimpleVarInfo(VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}}(x[1] => [1.1845589710704487], x[2] => [1.8351581568345814], x[3] => [0.726667648763101], x[4:5] => [0.19159541239321576, 0.8084045876067844]), 0.0)

julia> # Inherits from `AbstractDict`
       vi[@varname(x[1])]
1-element Vector{Float64}:
1.1845589710704487

julia> vi[@varname(x[4:5])]
2-element Vector{Float64}:
0.19159541239321576
0.8084045876067844

julia> vi[@varname(x[4:5][1])]
0.19159541239321576

julia> vi_linked = link(vi, model)
Transformed SimpleVarInfo(VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}}(x[1] => [1.1845589710704487], x[2] => [0.6071306668031453], x[3] => [-1.2135885975805698], x[4:5] => [-1.4396767388530345]), -3.354889980209639)

julia> vi_invlinked = invlink(vi_linked, model)
SimpleVarInfo(VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}}(x[1] => [1.1845589710704487], x[2] => [1.8351581568345814], x[3] => [0.726667648763101], x[4:5] => [0.19159541239321573, 0.8084045876067842]), -3.5819390425447497)
```
Uncertain if it's really worth it to make it an `AbstractDict` :shrug:
I like the `VarNameDict` mechanism. It combines the strengths of `NamedTuple` and `OrderedDict`, by allowing more flexible `Lens`-like recursive indexing behaviour for dicts, while also keeping values in a contiguous vector-like container.
Related: https://github.com/TuringLang/DynamicPPL.jl/issues/358 https://github.com/TuringLang/DynamicPPL.jl/issues/416
EDIT: reserving a field for some metadata, other than the values, would be helpful. For example:
```julia
struct VarNameDict{K,T,M,V<:AbstractVector{T},D<:AbstractDict{K}} <: AbstractDict{K,T}
    metadata::M
    values::V
    varname_to_ranges::D
end
```
I suggest that we rename `VarNameDict` to `VarDict`, but keep it a subtype of `AbstractDict`.
> I've successfully used it with ComponentVector from ComponentArrays.jl many times before.

Out of curiosity, what's the issue preventing us from using `ComponentArrays` as the default for `SimpleVarInfo`? That would save us from rolling our own home-baked `VarNameDict`/`VarDict`.
Indexing using `VarName`, mainly.
> EDIT: reserving a field for some metadata, other than value would be helpful. For example

I also thought about this but didn't want to add it in the initial version. We could do this, yeah, but it might also be best to just keep it separately in the varinfo? A bit uncertain.
Keeping it inside `VarNameDict`/`VarDict` feels more natural to me, where each key has its own value and metadata. But, yes, we can do this later.
Ah, one quite important thing we need from this `VarDict`, or anything similar: a way to convert back to the original form of the variable.
In `Metadata` this role is covered by the distributions; if we don't want to keep those around, we need to keep some other information around :confused:
It's not too bad to save this "shape" information. I think it can be done transparently when adding new variables to `VarDict`. Shape information can be extracted by the `vectorize` function and then saved to `VarDict`.
https://github.com/TuringLang/DynamicPPL.jl/blob/ad53ca9cdc6adc686b970ef5f9fdee83ab33462d/src/simple_varinfo.jl#L420-L430
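A minimal sketch of recording shape information when a variable is added, so the original form can be recovered without keeping the distribution around. Everything here is hypothetical: the `ShapedStore` type, the `flatten` helper (a stand-in for `vectorize`, not DynamicPPL's implementation) and the `Symbol` keys standing in for `VarName`. It only handles real scalars and real arrays, not e.g. `Cholesky` factorizations.

```julia
# Hypothetical container: flattened values plus, per key, a range and the original shape.
struct ShapedStore{T}
    vals::Vector{T}
    ranges::Dict{Symbol,UnitRange{Int}}
    shapes::Dict{Symbol,Any}   # `()` for scalars, `size(x)` for arrays
end

ShapedStore{T}() where {T} = ShapedStore(T[], Dict{Symbol,UnitRange{Int}}(), Dict{Symbol,Any}())

# Hypothetical stand-in for `vectorize`: flatten a value and record its shape.
flatten(x::Real) = ([x], ())
flatten(x::AbstractArray{<:Real}) = (vec(collect(x)), size(x))

function Base.setindex!(s::ShapedStore, x, k::Symbol)
    v, shape = flatten(x)
    n = length(s.vals)
    # Always append; old slots of an overwritten key are simply left unused.
    s.ranges[k] = (n + 1):(n + length(v))
    append!(s.vals, v)
    s.shapes[k] = shape
    return s
end

# Recover the original form from the flat values and the stored shape.
function Base.getindex(s::ShapedStore, k::Symbol)
    v = s.vals[s.ranges[k]]
    shape = s.shapes[k]
    return shape == () ? only(v) : reshape(v, shape)
end

s = ShapedStore{Float64}()
s[:sigma] = 1.5
s[:L] = [1.0 0.0; 0.5 2.0]
s[:sigma]  # 1.5
s[:L]      # the original 2×2 matrix
s.vals     # [1.5, 1.0, 0.5, 0.0, 2.0] (column-major)
```

Storing the shape per key is also one concrete use of the per-key metadata field suggested earlier in this thread.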