DynamicPPL.jl icon indicating copy to clipboard operation
DynamicPPL.jl copied to clipboard

Unable to sample when conditioning on an array with `|`

Open ParadaCarleton opened this issue 3 years ago • 6 comments

Fails:

julia> @model function gdemo()
          μ ~ Normal(0, 1)  # Our prior is a standard normal
          x .~ Normal(μ, 1)  # Our data follows a normal distribution
          return nothing
       end
gdemo (generic function with 2 methods)

julia> chn = sample(rng, gdemo() | (x = [1, -1, 0],), NUTS(), MCMCThreads(), 1_000, 12)
ERROR: TaskFailedException

    nested task error: TaskFailedException
    
        nested task error: UndefVarError: x not defined
        Stacktrace:
          [1] gdemo(__model__::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, Vector{Base.RefValue{Float64}}}, __context__::SamplingContext{SampleFromUniform, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}, Xoshiro})
            @ Main ./REPL[12]:3
          [2] macro expansion
            @ ~/.julia/dev/DynamicPPL.jl/src/model.jl:498 [inlined]
          [3] _evaluate!!
            @ ~/.julia/dev/DynamicPPL.jl/src/model.jl:481 [inlined]
          [4] evaluate_threadsafe!!(model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, varinfo::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, context::SamplingContext{SampleFromUniform, DefaultContext, Xoshiro})
            @ DynamicPPL ~/.julia/dev/DynamicPPL.jl/src/model.jl:472
          [5] evaluate!!
            @ ~/.julia/dev/DynamicPPL.jl/src/model.jl:407 [inlined]
          [6] evaluate!!
            @ ~/.julia/dev/DynamicPPL.jl/src/model.jl:420 [inlined]
          [7] Model
            @ ~/.julia/dev/DynamicPPL.jl/src/model.jl:380 [inlined]
          [8] VarInfo
            @ ~/.julia/dev/DynamicPPL.jl/src/varinfo.jl:127 [inlined]
          [9] VarInfo
            @ ~/.julia/dev/DynamicPPL.jl/src/varinfo.jl:126 [inlined]
         [10] step(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, spl::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}; resume_from::Nothing, kwargs::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}})
            @ DynamicPPL ~/.julia/dev/DynamicPPL.jl/src/sampler.jl:81
         [11] macro expansion
            @ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:124 [inlined]
         [12] macro expansion
            @ ~/.julia/packages/AbstractMCMC/6aLyN/src/logging.jl:15 [inlined]
         [13] mcmcsample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, sampler::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, kwargs::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:nadapts,), Tuple{Int64}}})
            @ AbstractMCMC ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:115
         [14] #sample#42
            @ ~/.julia/packages/Turing/rl6ku/src/inference/hmc.jl:133 [inlined]
         [15] macro expansion
            @ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:353 [inlined]
         [16] (::AbstractMCMC.var"#32#45"{Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, Xoshiro, UnitRange{Int64}, Bool, Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}}, Int64, Vector{Any}, Vector{UInt64}})()
            @ AbstractMCMC ./threadingconstructs.jl:178
    
    ...and 11 more exceptions.
    
    Stacktrace:
     [1] sync_end(c::Channel{Any})
       @ Base ./task.jl:381
     [2] macro expansion
       @ ./task.jl:400 [inlined]
     [3] macro expansion
       @ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:342 [inlined]
     [4] (::AbstractMCMC.var"#31#44"{Bool, Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}}, Int64, Int64, Vector{Any}, Vector{UInt64}, Vector{Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}}, Vector{Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}}, Vector{Xoshiro}, Int64, Int64})()
       @ AbstractMCMC ./task.jl:423
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:381
  [2] macro expansion
    @ ./task.jl:400 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:320 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/AbstractMCMC/6aLyN/src/logging.jl:8 [inlined]
  [6] mcmcsample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, sampler::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, ::MCMCThreads, N::Int64, nchains::Int64; progress::Bool, progressname::String, kwargs::Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/6aLyN/src/sample.jl:314
  [7] sample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, sampler::Sampler{NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, ensemble::MCMCThreads, N::Int64, n_chains::Int64; chain_type::Type, progress::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Turing.Inference ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:220
  [8] sample
    @ ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:220 [inlined]
  [9] #sample#6
    @ ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:205 [inlined]
 [10] sample(rng::Xoshiro, model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, alg::NUTS{Turing.Essential.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}, ensemble::MCMCThreads, N::Int64, n_chains::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/rl6ku/src/inference/Inference.jl:205
 [11] top-level scope
    @ REPL[17]:1

However, it seems to work when `x` isn't a vector and `.~` is replaced with `~`.

ParadaCarleton avatar Mar 09 '22 03:03 ParadaCarleton

Error messages are much easier to parse if you use regular sampling (which also rules out multithreading-related issues). In addition, the example depends on undefined variables (e.g. `rng`) and on extra packages. Here is an MWE without multithreading:

julia> using DynamicPPL, Distributions

julia> @model function gdemo()
          μ ~ Normal(0, 1)  # Our prior is a standard normal
          x .~ Normal(μ, 1)  # Our data follows a normal distribution
          return nothing
       end
gdemo (generic function with 2 methods)

julia> model = gdemo() | (x = [1, -1, 0],);

julia> model()
ERROR: UndefVarError: x not defined
Stacktrace:
 [1] gdemo(__model__::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, Vector{Base.RefValue{Float64}}}, __context__::SamplingContext{SampleFromPrior, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}, Random._GLOBAL_RNG})
   @ Main ./REPL[29]:3
 [2] macro expansion
   @ ~/.julia/dev/DynamicPPL/src/model.jl:498 [inlined]
 [3] _evaluate!!
   @ ~/.julia/dev/DynamicPPL/src/model.jl:481 [inlined]
 [4] evaluate_threadsafe!!(model::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}}, varinfo::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, context::SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG})
   @ DynamicPPL ~/.julia/dev/DynamicPPL/src/model.jl:472
 [5] evaluate!!
   @ ~/.julia/dev/DynamicPPL/src/model.jl:407 [inlined]
 [6] evaluate!! (repeats 2 times)
   @ ~/.julia/dev/DynamicPPL/src/model.jl:420 [inlined]
 [7] evaluate!!
   @ ~/.julia/dev/DynamicPPL/src/model.jl:428 [inlined]
 [8] (::Model{typeof(gdemo), (), (), (), Tuple{}, Tuple{}, ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Vector{Int64}}}, DefaultContext}})()
   @ DynamicPPL ~/.julia/dev/DynamicPPL/src/model.jl:380
 [9] top-level scope
   @ REPL[34]:1

devmotion avatar Mar 09 '22 22:03 devmotion

~~The problem is https://github.com/TuringLang/DynamicPPL.jl/blob/f2eb6357d3d52efbc8457fca08806c73da3acc09/src/compiler.jl#L473-L475, which ends up in the model function in this case as~~

(DynamicPPL.unwrap_right_left_vns)((DynamicPPL.check_tilde_rhs)(Normal(μ, 1)), x, var"##vn#436")...

~~before x is sampled. The undotted version works (and should probably be preferred here) since it does not refer to x before sampling https://github.com/TuringLang/DynamicPPL.jl/blob/f2eb6357d3d52efbc8457fca08806c73da3acc09/src/compiler.jl#L423.~~

devmotion avatar Mar 10 '22 00:03 devmotion

Ah, forget my last comment; we're in the `!isassumption` branch, of course.

devmotion avatar Mar 10 '22 00:03 devmotion

The problem is at https://github.com/TuringLang/DynamicPPL.jl/blob/f2eb6357d3d52efbc8457fca08806c73da3acc09/src/compiler.jl#L450: there, the value of `x` is unpacked correctly on the RHS but, of course, it can't be assigned in-place to a variable `x` that does not exist yet. Changing `.=` to `=` there fixes the issue.

devmotion avatar Mar 10 '22 00:03 devmotion

> Error messages are much easier to parse if you use regular sampling (it also rules out multithreading related issues). Additionally, the example requires additional undefined variables and dependencies. A MWE without multithreading:

Thanks, that makes sense; I'll keep it in mind for the future.

ParadaCarleton avatar Mar 11 '22 23:03 ParadaCarleton

> Changing .= there to = fixes the issue.

Just for the record, making this change will make `.~` inconsistent with standard broadcasting behavior: e.g., something like `x[1:2, 3:4] .~ fill(Normal(0, 1), 4)` will work when we use `.=` but will fail if we use `=`.

Unfortunately, I can't see a clear path to adding support for something like this :/ There are ways of handling it in the model, but they are somewhat hacky at the moment (see #412).

torfjelde avatar Jun 11 '22 14:06 torfjelde