
Sampling from prior using return value doesn't work with MvNormal

[Open] knuesel opened this issue 3 years ago • 8 comments

The guide shows how to sample from the prior using return values:

@model function gdemo(x, y)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
    return x, y
end

g_prior_sample = gdemo(missing, missing)
g_prior_sample()

I can rewrite it like this and it still works:

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x[1] ~ Normal(m, sqrt(s))
    x[2] ~ Normal(m, sqrt(s))
    return x
end

g_prior_sample = gdemo([missing, missing])
g_prior_sample()

However in my model I have many uses of MvNormal. The equivalent in this small example would be:

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ MvNormal([m, m], sqrt(s))
    return x
end

g_prior_sample = gdemo([missing, missing])
g_prior_sample()

On 0.15.1 this fails with MethodError: no method matching loglikelihood(::IsoNormal, ::Vector{Union{Missing, Float64}}).

Stacktrace
  [1] observe(spl::DynamicPPL.SampleFromPrior, dist::IsoNormal, value::Vector{Union{Missing, Float64}}, vi::DynamicPPL.UntypedVarInfo{DynamicPPL.Metadata{Dict{DynamicPPL.VarName, Int64}, Vector{Distribution}, Vector{DynamicPPL.VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/wf0dU/src/context_implementations.jl:152
  [2] _tilde(sampler::DynamicPPL.SampleFromPrior, right::IsoNormal, left::Vector{Union{Missing, Float64}}, vi::DynamicPPL.UntypedVarInfo{DynamicPPL.Metadata{Dict{DynamicPPL.VarName, Int64}, Vector{Distribution}, Vector{DynamicPPL.VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/wf0dU/src/context_implementations.jl:109
  [3] tilde(ctx::DynamicPPL.DefaultContext, sampler::DynamicPPL.SampleFromPrior, right::IsoNormal, left::Vector{Union{Missing, Float64}}, vi::DynamicPPL.UntypedVarInfo{DynamicPPL.Metadata{Dict{DynamicPPL.VarName, Int64}, Vector{Distribution}, Vector{DynamicPPL.VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/wf0dU/src/context_implementations.jl:67
  [4] tilde_observe(ctx::DynamicPPL.DefaultContext, sampler::DynamicPPL.SampleFromPrior, right::IsoNormal, left::Vector{Union{Missing, Float64}}, vname::DynamicPPL.VarName{:x, Tuple{}}, vinds::Tuple{}, vi::DynamicPPL.UntypedVarInfo{DynamicPPL.Metadata{Dict{DynamicPPL.VarName, Int64}, Vector{Distribution}, Vector{DynamicPPL.VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/wf0dU/src/context_implementations.jl:89
  [5] #33
    @ ./In[81]:8 [inlined]
  [6] (::var"#33#34")(_rng::Random._GLOBAL_RNG, _model::DynamicPPL.Model{var"#33#34", (:x,), (), (), Tuple{Vector{Missing}}, Tuple{}}, _varinfo::DynamicPPL.UntypedVarInfo{DynamicPPL.Metadata{Dict{DynamicPPL.VarName, Int64}, Vector{Distribution}, Vector{DynamicPPL.VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, _sampler::DynamicPPL.SampleFromPrior, _context::DynamicPPL.DefaultContext, x::Vector{Union{Missing, Float64}})
    @ Main ./none:0
  [7] macro expansion
    @ ~/.julia/packages/DynamicPPL/wf0dU/src/model.jl:0 [inlined]
  [8] _evaluate
    @ ~/.julia/packages/DynamicPPL/wf0dU/src/model.jl:154 [inlined]
  [9] evaluate_threadunsafe
    @ ~/.julia/packages/DynamicPPL/wf0dU/src/model.jl:127 [inlined]
 [10] Model
    @ ~/.julia/packages/DynamicPPL/wf0dU/src/model.jl:92 [inlined]
 [11] Model
    @ ~/.julia/packages/DynamicPPL/wf0dU/src/model.jl:91 [inlined]
 [12] (::DynamicPPL.Model{var"#33#34", (:x,), (), (), Tuple{Vector{Missing}}, Tuple{}})()
    @ DynamicPPL ~/.julia/packages/DynamicPPL/wf0dU/src/model.jl:98
 [13] top-level scope
    @ In[84]:2
 [14] eval
    @ ./boot.jl:360 [inlined]
 [15] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1090

knuesel commented on Mar 12 '21

Can you check if it still fails with the latest release (0.15.12)? I assume it does but it's good to check it first.

In general, I assume the problem is that the sampling of variables only works when the left-hand side of ~ (such as x and y in the first example, and x[1] and x[2] in the second) is missing, but not when it is a vector of missing values: https://github.com/TuringLang/DynamicPPL.jl/blob/2b4c550a94a3dba14ea440b1f70ecb72cee2bb9b/src/compiler.jl#L31
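
Roughly, the check boils down to whether the left-hand side value itself is missing; the following is an illustration of the idea, not the actual compiler code:

left = [missing, missing]
left === missing     # false: the vector itself is not `missing`, so
                     # `x ~ MvNormal(...)` is compiled to an observe statement
left[1] === missing  # true: with `x[1] ~ Normal(...)` the element is missing,
                     # so that tilde is compiled to an assume (sampling) statement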

A more reliable alternative to passing x and y as missing would be to create a new model that marks the variables as missing, with something like this (untested, so it may contain a bug):

model = gdemo(3.0, 4.2)
model_missing = DynamicPPL.Model{(:x,:y)}(:gdemo_missing, model.f, model.args, model.defaults)
model_missing()

devmotion commented on Mar 12 '21

Yes, I get the same error with 0.15.12 (tested on Julia 1.5.3; I could not get 0.15.12 to install on Julia 1.6 rc1 or rc2).

I assume the problem is that the sampling of variables only works when the left-hand side of ~ (such as x and y in the first example, and x[1] and x[2] in the second) is missing, but not when it is a vector of missing values

but the second example also uses a vector of missings, and it does work:

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x[1] ~ Normal(m, sqrt(s))
    x[2] ~ Normal(m, sqrt(s))
    return x
end

g_prior_sample = gdemo([missing, missing])
g_prior_sample()  # no error

Note that this works:

loglikelihood(Normal(), Union{Missing,Float64}[1,2,3])

but this doesn't work:

loglikelihood(MvNormal([1.0, 2.0], 3.0), Union{Missing,Float64}[1 2 3; 4 5 6])

ERROR: MethodError: no method matching loglikelihood(::MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}}, ::Array{Union{Missing, Float64},2})

So maybe it's a problem with Distributions.jl (or Turing should call disallowmissing before loglikelihood)?
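
For illustration, converting the element type first would let the existing method apply (disallowmissing is from Missings.jl; this of course only helps when no value is actually missing):

using Distributions, Missings

x = Union{Missing,Float64}[1 2 3; 4 5 6]
# disallowmissing narrows the eltype to Float64, so the existing
# loglikelihood(::MvNormal, ::AbstractMatrix{<:Real}) method dispatches:
loglikelihood(MvNormal([1.0, 2.0], 3.0), disallowmissing(x))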

knuesel commented on Mar 13 '21

but the second example also uses a vector of missing and it does work:

The type of x is not relevant here: it only matters what values are on the left-hand side of the ~ expressions in the model. So, as mentioned above, in this example it only matters what x[1] and x[2] are. They are both missing, whereas in the failing example the variable x on the left-hand side of the x ~ MvNormal(...) expression is a vector of missings.

devmotion commented on Mar 13 '21

So maybe it's a problem with Distributions.jl (or Turing should call disallowmissing before loglikelihood)?

Maybe one could relax the type constraints in Distributions (AFAIK they currently require an AbstractMatrix{<:Real} for MvNormal), but to me it seems the main problem is that these methods are called at all. I still think it would be better to indicate that these variables are missing in the model definition instead of passing missing values around.

devmotion commented on Mar 13 '21

I see, thanks. As per your suggestion, the following works:

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ MvNormal([m, m], sqrt(s))
    return x
end

gp = gdemo([0.0, 0.0])

gpmis = DynamicPPL.Model{(:x,)}(:gpmis, gp.f, gp.args, gp.defaults)
gpmis()

It is a bit involved, though, compared to the method proposed in the guide.

Also, with x[i] ~ ... we can mark selected values in the vector as missing, writing for example gdemo([0.0, missing]). Wouldn't it be desirable to have this work with x ~ MvNormal(...) too?
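
For now, the elementwise form does handle that mixed case, e.g. (a sketch of the loop variant of the model above; the name is made up):

@model function gdemo_loop(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
    return x
end

gdemo_loop([0.0, missing])()  # x[1] is treated as observed, x[2] is sampled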

knuesel commented on Mar 24 '21

Maybe this is an obvious comment since I'm kind of new to Turing, but I also ran into this and banged my head against the wall for a while, only to realize that the correct solution here is just to pass in a single missing rather than a vector. With Turing v0.15.24 / DynamicPPL v0.10.20, it all works automatically:

julia> @model function gdemo(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x ~ MvNormal([m, m], sqrt(s))
           return x
       end
gdemo (generic function with 1 method)

julia> gdemo(missing)()
2-element Vector{Float64}:
 -1.5916479638357828
 -2.359586221863498

julia> gdemo([1,2])()
2-element Vector{Int64}:
 1
 2

Of course this makes it impossible to pass mixed [0.0, missing] style variables, which I agree would be nice, but at least the original case works fine.

marius311 commented on May 25 '21

I think the correct solution here is just to pass in a single missing rather than a vector.

Yes, this is what one can do, as pointed out earlier in this discussion.

Of course this makes it impossible to pass mixed [0.0, missing] style variables, which I agree would be nice, but at least the original case works fine.

In general one has to condition on the existing values, which is not always as trivial as for a multivariate normal distribution with iid components.
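
For a correlated MvNormal, for example, one would need the usual Gaussian conditioning formulas; a sketch (not existing Turing functionality):

using Distributions, LinearAlgebra

# Conditional of the missing block given the observed block of x ~ MvNormal(μ, Σ):
#   mean:       μ₂ + Σ₂₁ Σ₁₁⁻¹ (x₁ - μ₁)
#   covariance: Σ₂₂ - Σ₂₁ Σ₁₁⁻¹ Σ₁₂
function conditional_mvnormal(μ, Σ, obs, mis, x_obs)
    Σ11 = Σ[obs, obs]
    Σ21 = Σ[mis, obs]
    Σ22 = Σ[mis, mis]
    μc = μ[mis] + Σ21 * (Σ11 \ (x_obs - μ[obs]))
    Σc = Σ22 - Σ21 * (Σ11 \ Σ21')
    return MvNormal(μc, Symmetric(Σc))
end

# With iid components (Σ = σ²I), Σ₂₁ is zero and this reduces to sampling
# each missing entry independently, which is what makes that case easy.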

devmotion commented on May 25 '21

In general one has to condition on the existing values which is not always as trivial as for a multivariate normal distribution with iid components.

Sure, but this is one of the more common use cases. The default behavior should be to simply infer the missing values from the sampling distribution, the same as would happen if one looped over univariate distributions. This is the behavior of pymc3, and it's quite nice because it allows almost effortless handling of missing values under appropriate sampling distributions.

That being said, pymc3 does print a warning, presumably to remind the user that it's up to them to consider whether or not this behavior is actually correct for their model. This seems like a perfectly acceptable responsibility to offload to the user, imho.

bgroenks96 commented on Jun 27 '21