DynamicPPL.jl
DynamicPPL.jl copied to clipboard
Unable to sample when conditioning on an array with `|`
Fails:
julia> @model function gdemo()
μ ~ Normal(0, 1) # Our prior is a standard normal
x .~ Normal(μ, 1) # Our data follows a normal distribution
return nothing
end
gdemo (generic function with 2 methods)
julia> chn = sample(rng, gdemo() | (x = [1, -1, 0],), NUTS(), MCMCThreads(), 1_000, 12)
ERROR: TaskFailedException
nested task error: TaskFailedException
nested task error: UndefVarError: x not defined
Stacktrace:
[1] gdemo(__model__::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, Vector{Base.RefValue{Float64}}}, __context__::SamplingContext{SampleFromUniform, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}, Xoshiro})
@ Main ./REPL[12]:3
[2] macro expansion
@ ~/.julia/dev/DynamicPPL.jl/src/model.jl:498 [inlined]
[3] _evaluate!!
@ ~/.julia/dev/DynamicPPL.jl/src/model.jl:481 [inlined]
[4] evaluate_threadsafe!!(model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, varinfo::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, context::SamplingContext{SampleFromUniform, DefaultContext, Xoshiro})
@ DynamicPPL ~/.julia/dev/DynamicPPL.jl/src/model.jl:472
[5] evaluate!!
@ ~/.julia/dev/DynamicPPL.jl/src/model.jl:407 [inlined]
[6] evaluate!!
@ ~/.julia/dev/DynamicPPL.jl/src/model.jl:420 [inlined]
[7] Model
@ ~/.julia/dev/DynamicPPL.jl/src/model.jl:380 [inlined]
[8] VarInfo
@ ~/.julia/dev/DynamicPPL.jl/src/varinfo.jl:127 [inlined]
[9] VarInfo
@ ~/.julia/dev/DynamicPPL.jl/src/varinfo.jl:126 [inlined]
[10] step(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, spl::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}; resume_from::Nothing, kwargs::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}})
@ DynamicPPL ~/.julia/dev/DynamicPPL.jl/src/sampler.jl:81
[11] macro expansion
@ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:124 [inlined]
[12] macro expansion
@ ~/.julia/packages/AbstractMCMC/6aLyN/src/logging.jl:15 [inlined]
[13] mcmcsample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, sampler::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, kwargs::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}})
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:115
[14] #sample#42
@ ~/.julia/packages/Turing/rl6ku/src/inference/hmc.jl:133 [inlined]
[15] macro expansion
@ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:353 [inlined]
[16] (::AbstractMCMC.var"#32#45"{Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, Xoshiro, UnitRange{Int64}, Bool, Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}}, Int64, Vector{Any}, Vector{UInt64}})()
@ AbstractMCMC ./threadingconstructs.jl:178
...and 11 more exceptions.
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:381
[2] macro expansion
@ ./task.jl:400 [inlined]
[3] macro expansion
@ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:342 [inlined]
[4] (::AbstractMCMC.var"#31#44"{Bool, Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}}, Int64, Int64, Vector{Any}, Vector{UInt64}, Vector{Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}}, Vector{Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}}, Vector{Xoshiro}, Int64, Int64})()
@ AbstractMCMC ./task.jl:423
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:381
[2] macro expansion
@ ./task.jl:400 [inlined]
[3] macro expansion
@ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:320 [inlined]
[4] macro expansion
@ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
[5] macro expansion
@ ~/.julia/packages/AbstractMCMC/6aLyN/src/logging.jl:8 [inlined]
[6] mcmcsample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, sampler::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, ::MCMCThreads, N::Int64, nchains::Int64; progress::Bool, progressname::String, kwargs::Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}})
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:314
[7] sample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, sampler::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, ensemble::MCMCThreads, N::Int64, n_chains::Int64; chain_type::Type, progress::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Turing.Inference ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:220
[8] sample
@ ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:220 [inlined]
[9] #sample#6
@ ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:205 [inlined]
[10] sample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, alg::NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}, ensemble::MCMCThreads, N::Int64, n_chains::Int64)
@ Turing.Inference ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:205
[11] top-level scope
@ REPL[17]:1
But it seems to work when `x` isn't a vector and `.~` is replaced with `~`.
Error messages are much easier to parse if you use regular sampling (it also rules out multithreading-related issues). Moreover, the example relies on undefined variables (e.g. `rng`) and unstated dependencies. Here is an MWE without multithreading:
julia> using DynamicPPL, Distributions
julia> @model function gdemo()
μ ~ Normal(0, 1) # Our prior is a standard normal
x .~ Normal(μ, 1) # Our data follows a normal distribution
return nothing
end
gdemo (generic function with 2 methods)
julia> model = gdemo() | (x = [1, -1, 0],);
julia> model()
ERROR: UndefVarError: x not defined
Stacktrace:
[1] gdemo(__model__::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, Vector{Base.RefValue{Float64}}}, __context__::SamplingContext{SampleFromPrior, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}, Random._GLOBAL_RNG})
@ Main ./REPL[29]:3
[2] macro expansion
@ ~/.julia/dev/DynamicPPL/src/model.jl:498 [inlined]
[3] _evaluate!!
@ ~/.julia/dev/DynamicPPL/src/model.jl:481 [inlined]
[4] evaluate_threadsafe!!(model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, varinfo::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, context::SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG})
@ DynamicPPL ~/.julia/dev/DynamicPPL/src/model.jl:472
[5] evaluate!!
@ ~/.julia/dev/DynamicPPL/src/model.jl:407 [inlined]
[6] evaluate!! (repeats 2 times)
@ ~/.julia/dev/DynamicPPL/src/model.jl:420 [inlined]
[7] evaluate!!
@ ~/.julia/dev/DynamicPPL/src/model.jl:428 [inlined]
[8] (::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}})()
@ DynamicPPL ~/.julia/dev/DynamicPPL/src/model.jl:380
[9] top-level scope
@ REPL[34]:1
~~The problem is https://github.com/TuringLang/DynamicPPL.jl/blob/f2eb6357d3d52efbc8457fca08806c73da3acc09/src/compiler.jl#L473-L475, which ends up in the model function in this case as~~
(DynamicPPL.unwrap_right_left_vns)((DynamicPPL.check_tilde_rhs)(Normal(μ, 1)), x, var"##vn#436")...
~~before x is sampled. The undotted version works (and should probably be preferred here) since it does not refer to x before sampling https://github.com/TuringLang/DynamicPPL.jl/blob/f2eb6357d3d52efbc8457fca08806c73da3acc09/src/compiler.jl#L423.~~
Ah, forget my last comment: we're in the `!isassumption` branch, of course.
The problem is https://github.com/TuringLang/DynamicPPL.jl/blob/f2eb6357d3d52efbc8457fca08806c73da3acc09/src/compiler.jl#L450: there, the value of `x` is correctly unpacked on the right-hand side, but it cannot be assigned in place (with `.=`) to a variable `x` that does not exist yet. Changing `.=` to `=` there fixes the issue.
> Error messages are much easier to parse if you use regular sampling (it also rules out multithreading-related issues). Moreover, the example relies on undefined variables (e.g. `rng`) and unstated dependencies. Here is an MWE without multithreading:
Thanks, that makes sense; I'll keep it in mind for the future.
> Changing `.=` to `=` there fixes the issue.
Just for the record, making this change would make `.~` inconsistent with standard broadcasting behavior. For example, something like `x[1:2, 3:4] .~ fill(Normal(0, 1), 4)` works when we use `.=`, but fails if we use `=`.
Unfortunately, I can't see a clear path to supporting something like this. There are ways of handling it in the model, but they are somewhat hacky at the moment (see #412).