
[WIP] More flexibility in RHS of `~`, e.g. MeasureTheory.jl

Open torfjelde opened this issue 2 years ago • 14 comments

I've recently been thinking a bit about how it would be nice to support more than just Distribution from Distributions.jl on the RHS of a ~ statement (and other values as samples, i.e. on the LHS of ~), and as I was looking through some code today I realized that it shouldn't be too difficult.

Hence this PR, which demonstrates what it would take to add this feature, using MeasureTheory.jl as an example use-case. All changes outside of src/measuretheory.jl are general changes required to accommodate non-Distribution types on the RHS of ~.

julia> using DynamicPPL, MeasureTheory

julia> @model function demo(; x=missing, n = x isa AbstractArray ? length(x) : 1)
           m ~ Normal(μ=0.0, σ=1.0)
           x ~ For(1:n) do i
               Normal(μ=m, σ=1.0)
           end
       end

demo (generic function with 1 method)

julia> m() # sampling
3-element Vector{Float64}:
 -1.106403180421966
  0.40711759833021666
 -2.46921196310957

julia> vi = VarInfo(m); m(vi, DefaultContext()) # evaluation
3-element view(::Vector{Float64}, 1:3) with eltype Float64:
 -0.10556975107508737
  0.578507546477508
 -1.4482491679503848

@cscherrer :)

torfjelde avatar Jul 27 '21 00:07 torfjelde

THIS WOULD BE SO GREAT!!!

Thanks @torfjelde for taking the initiative on this. BTW, a Soss model is already a measure, so this would make it easy to use Soss from within Turing. I wonder, what would it take to make a Turing model a measure, or even to have a wrapper around one that would make it behave in this way? That could be a good way to get things working together more easily.

cscherrer avatar Jul 27 '21 01:07 cscherrer

> BTW, a Soss model is already a measure, so this would make it easy to use Soss from within Turing. I wonder, what would it take to make a Turing model a measure, or even to have a wrapper around one that would make it behave in this way? That could be a good way to get things working together more easily.

That's a good point! Honestly, I think the main issue is the linearization that DPPL currently requires. If you have a good way of linearizing the nested tuple sample from Soss, then it shouldn't be much of a leap from this branch :)

torfjelde avatar Jul 27 '21 01:07 torfjelde

There are some tricks in https://github.com/cscherrer/NestedTuples.jl, maybe something from there can help. We can't throw away the structure, but maybe this "leaf setter" thing? https://github.com/cscherrer/NestedTuples.jl#leaf-setter

Yeah, naming things is hard :)

cscherrer avatar Jul 27 '21 01:07 cscherrer

> There are some tricks in https://github.com/cscherrer/NestedTuples.jl, maybe something from there can help. We can't throw away the structure, but maybe this "leaf setter" thing? https://github.com/cscherrer/NestedTuples.jl#leaf-setter
>
> Yeah, naming things is hard :)

But how do you use DynamicHMC then? Surely HMC requires some form of linearization of the parameters?

And we don't need to throw away the structure, we only need to temporarily hide it and revert the change once we reach the Soss model:)

torfjelde avatar Jul 27 '21 12:07 torfjelde

> But how do you use DynamicHMC then? Surely HMC requires some form of linearization of the parameters?

You define a transform, like say

t = as((a = asℝ₊, b = as((b1 = asℝ, b2 = as(Vector, as𝕀, 3), b3=CorrCholesky(4)))))

Soss automates composing one of these for a given model. Yes, these are linearized, but there are no names, and everything is stretched in order to have a density over ℝⁿ. I had thought this is also how Turing does things using Bijectors, but maybe that's wrong?

> And we don't need to throw away the structure, we only need to temporarily hide it and revert the change once we reach the Soss model:)

Thinking some more about this, it seems like ParameterHandling.flatten could work well: https://invenia.github.io/ParameterHandling.jl/dev/#ParameterHandling.flatten
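
A quick sketch of the `ParameterHandling.flatten` API linked above (assuming ParameterHandling.jl is installed; API as documented): it returns a flat vector together with a closure that rebuilds the original structure.

```julia
# Sketch of ParameterHandling.flatten: nested NamedTuple -> flat vector
# plus an `unflatten` closure that restores the original structure.
using ParameterHandling

x = (a = 1.0, b = (c = 2.0,))
v, unflatten = ParameterHandling.flatten(x)

v            # flat Vector{Float64} of the leaf values
unflatten(v) # reconstructs (a = 1.0, b = (c = 2.0,))
```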

cscherrer avatar Jul 27 '21 14:07 cscherrer

I'll respond to you here to keep things a bit organized:)

> > But how do you use DynamicHMC then? Surely HMC requires some form of linearization of the parameters?
>
> You define a transform, like say
>
> t = as((a = asℝ₊, b = as((b1 = asℝ, b2 = as(Vector, as𝕀, 3), b3=CorrCholesky(4)))))

As things currently are, when storing a variable in the trace (VarInfo), we essentially flatten it into a vector and store the variable names, with their corresponding ranges into that vector, in a separate vector. Hence we'd need to do the same with a Soss model's output, i.e. extract the linear shape plus the variable names used. I want whatever the above as does internally to make it into a vector.

> Soss automates composing one of these for a given model. Yes, these are linearized, but there are no names, and everything is stretched in order to have a density over ℝⁿ.

No, that's right, and that's my point :) But

> I had thought this is also how Turing does things using Bijectors, but maybe that's wrong?

This is completely independent of the usage of Bijectors.jl though; we'll reshape back into the original shape of the variable before getting the transform from Bijectors.jl. We could also define bijector(measure) for the different measures to return the transformation taking us from the domain of the measure to ℝⁿ, but that will only allow us to share some code for the transformation (e.g. Bijectors.logpdf_with_trans); it won't do anything to address the issue that we need a flattened representation in VarInfo.

> Thinking some more about this, it seems like ParameterHandling.flatten could work well: https://invenia.github.io/ParameterHandling.jl/dev/#ParameterHandling.flatten

Exactly, but we want the symbols, e.g. say we have this

x = (a = 1.0, b = (c = 2.0, ))
x ~ SossModel()

then internally we'd want something like

vals = [1.0, 2.0]
varnames = [VarName{Symbol("x.a")}(), VarName{Symbol("x.b.c")}()]

We also want a transformation for which we can compute the logdensity, but this should take x in its original shape, not the vector vals, i.e. a separate issue.
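
The flattening described above could be sketched as follows (a hypothetical helper, not part of DPPL: it walks a nested NamedTuple, collecting leaf values and dotted names):

```julia
# Hypothetical sketch: flatten a nested NamedTuple into a vector of leaf
# values plus dotted name strings like "x.a" and "x.b.c".
function flatten_named(x::NamedTuple, prefix::String = "x")
    vals, names = Float64[], String[]
    for k in keys(x)
        v = getproperty(x, k)
        name = string(prefix, ".", k)
        if v isa NamedTuple
            # Recurse into nested tuples, extending the dotted prefix.
            subvals, subnames = flatten_named(v, name)
            append!(vals, subvals)
            append!(names, subnames)
        else
            push!(vals, v)
            push!(names, name)
        end
    end
    return vals, names
end

vals, names = flatten_named((a = 1.0, b = (c = 2.0,)))
# vals == [1.0, 2.0], names == ["x.a", "x.b.c"]
```

The dotted names could then be wrapped in `VarName` symbols as in the example above.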

torfjelde avatar Jul 27 '21 15:07 torfjelde

Ok, I think I see. NestedTuples has a lenses function that... well maybe an example is best:

julia> x
(a = 2.0509701709447876, b = (b1 = -0.31411507894223795, b2 = [0.19141842948352514, 0.3248182896463582, 0.6726129111118845], b3 = LinearAlgebra.Cholesky{Float64, LinearAlgebra.UpperTriangular{Float64, Matrix{Float64}}}([1.0 -0.24542409737564375 -0.09619507309595421 -0.3347648223848326; 0.0 0.969415809870744 -0.6940440834677906 -0.25974375198705785; 0.0 0.0 0.7134769219220889 -0.14376094280898585; 0.0 0.0 0.0 0.8943145354515987], 'U', 0)))

julia> NestedTuples.lenses(x)
((@optic _.a), (@optic _.b1) ∘ (@optic _.b), (@optic _.b2) ∘ (@optic _.b), (@optic _.b3) ∘ (@optic _.b))

julia> typeof(NestedTuples.lenses(x)[3])
ComposedFunction{Accessors.PropertyLens{:b2}, Accessors.PropertyLens{:b}}

Currently I'm stopping when I hit an array, but Accessors can also handle these, for example

julia> @optic _.b.b2[3]
(@optic _[3]) ∘ (@optic _.b2) ∘ (@optic _.b)

cscherrer avatar Jul 27 '21 15:07 cscherrer

@torfjelde Is it correct that you need things entirely unrolled, so each name is for a scalar? Also, do you need to be able to reconstruct everything from the names alone, or can there be something carried along with it to make this easier?

cscherrer avatar Jul 27 '21 15:07 cscherrer

> Currently I'm stopping when I hit an array, but Accessors can also handle these, for example

I've actually played around with replacing all this indexing behavior, etc. in Turing by the lenses from Setfield.jl (Accessors.jl seems more unstable from the README, and so it's somewhat unlikely we'll use that atm?).

The annoying case, and the case that stopped me from replacing VarName indexing with Setfield.jl's lenses, is the handling of begin and end. It's difficult: https://github.com/TuringLang/AbstractPPL.jl/pull/25 :)

But I think I have a way of addressing this actually.

> @torfjelde Is it correct that you need things entirely unrolled, so each name is for a scalar? Also, do you need to be able to reconstruct everything from the names alone, or can there be something carried along with it to make this easier?

Not quite:) We also have lists of ranges and dists. So we can for example have

vals = randn(4)
ranges = [1:1, 2:4]
dists = [MvNormal(1, 1.0), MvNormal(3, 1.0)]
varnames = [@varname(x[1:1]), @varname(x[2:4])]

And whenever we encounter, say, x[2:4] in the model, we can extract a correctly sized value from VarInfo by using the size of the dists. See the reconstruct and vectorize functions that I've overloaded in this PR.

So essentially what I'm asking for is a reconstruct and vectorize for Soss-models:)
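
Roughly, the pair could look like this (a toy sketch, not the PR's actual code; a plain size tuple stands in for the distribution/measure that the real versions dispatch on):

```julia
# Toy sketch of a vectorize/reconstruct pair: vectorize flattens a
# sample into a vector, reconstruct restores the original shape.
vectorize(x::Real) = [x]
vectorize(x::AbstractArray) = vec(copy(x))

reconstruct(::Tuple{}, v::AbstractVector) = v[1]           # scalar variable
reconstruct(sz::Tuple, v::AbstractVector) = reshape(v, sz) # array variable

v = vectorize([1.0 2.0; 3.0 4.0])  # column-major flattening
x = reconstruct((2, 2), v)         # back to the original 2×2 matrix
```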

torfjelde avatar Jul 27 '21 15:07 torfjelde

> I've actually played around with replacing all this indexing behavior, etc. in Turing by the lenses from Setfield.jl (Accessors.jl seems more unstable from the README, and so it's somewhat unlikely we'll use that atm?).

Yeah, I'm not too worried about stability. We have version dependencies anyway, plus it's just not that much code. Seems worth it IMO to have easy inroads to ongoing improvements. But Setfield is fine too, it shouldn't matter that much.

> The annoying case, and the case that stopped me from replacing VarName indexing with Setfield.jl's lenses is the handling of begin and end. It's difficult TuringLang/AbstractPPL.jl#25 :)

I don't understand this at all. Is there a toy example?

> Not quite:) We also have lists of ranges and dists.

I see, yeah that does complicate things.

> So essentially what I'm asking for is a reconstruct and vectorize for Soss-models:)

Ok I'll have a look :)

cscherrer avatar Jul 27 '21 16:07 cscherrer

> I don't understand this at all. Is there a toy example?

Eh no need, I've made a PR now anyways: https://github.com/TuringLang/AbstractPPL.jl/pull/26

> I see, yeah that does complicate things.

Well, it sort of makes things easier:) Just look at the impls I have for MeasureTheory now. All we really need is a way to convert a named tuple into a vector given a Soss-model. So like ParameterHandling.flatten, but without the closure.

torfjelde avatar Jul 27 '21 16:07 torfjelde

> Yeah, I'm not too worried about stability. We have version dependencies anyway, plus it's just not that much code. Seems worth it IMO to have easy inroads to ongoing improvements. But Setfield is fine too, it shouldn't matter that much.

Not to press this point because I'm with you on what you just said, but I just realized that BangBang.jl uses Setfield.jl and we're likely to be making use of BangBang.jl in DPPL very soon. In particular I love that there's a @setfield!!! https://github.com/JuliaFolds/BangBang.jl/blob/master/src/setfield.jl

torfjelde avatar Jul 27 '21 16:07 torfjelde

> Not to press this point because I'm with you on what you just said, but I just realized that BangBang.jl uses Setfield.jl and we're likely to be making use of BangBang.jl in DPPL very soon. In particular I love that there's a @setfield!!! https://github.com/JuliaFolds/BangBang.jl/blob/master/src/setfield.jl

Yeah, BangBang is pretty great. Have you seen Kaleido? Definitely worth a look as well

cscherrer avatar Jul 27 '21 17:07 cscherrer

> Yeah, BangBang is pretty great. Have you seen Kaleido? Definitely worth a look as well

I think the aim of the two is a bit different, no? Both use Setfield.jl under the hood, but BangBang is all about using mutation when it makes sense and not when it doesn't, e.g. arrays are mutated in place rather than copied. Kaleido looks cool as an extended Setfield.jl though! But it seems like overkill for what we'll need in DPPL to accomplish what we want, e.g. https://github.com/TuringLang/AbstractPPL.jl/pull/26 .

torfjelde avatar Jul 28 '21 18:07 torfjelde

I think this is superseded by https://github.com/TuringLang/DynamicPPL.jl/pull/342

yebai avatar Nov 02 '22 20:11 yebai