Turing.jl icon indicating copy to clipboard operation
Turing.jl copied to clipboard

Missing keyword arguments are not properly processed

Open daeh opened this issue 1 year ago • 5 comments

The documentation gives an example of how a `@model` argument can be used either to condition the model on observed data or, by passing `missing`, to sample the corresponding random variables (RVs):

# Positional-argument version of the demo model. Passing `missing` for `x`
# makes the model draw `x` instead of conditioning on supplied data.
@model function gdemo(x, ::Type{T}=Float64) where {T}
    # No data supplied: allocate a 2-element container whose entries the
    # `~` statements below will fill in.
    x = ismissing(x) ? Vector{T}(undef, 2) : x

    s² ~ InverseGamma(2, 3)
    σ = sqrt(s²)  # plain local, not a model variable
    m ~ Normal(0, σ)
    for k in eachindex(x)
        x[k] ~ Normal(m, σ)
    end
end

# Build the model with x = missing, then draw 500 HMC samples
model = gdemo(missing)
c = sample(model, HMC(0.01, 5), 500)

If `x` is turned into a keyword argument, the same example fails with an error:

using Turing

# Same model as `gdemo`, but with `x` passed as a keyword argument.
# With the versions reported in this issue (Turing v0.32.3 and v0.33.0),
# calling it with `x=missing` does not behave like `gdemo(missing)`: HMC
# sampling fails with a `DomainError` from `Normal` (see the stack trace below).
# NOTE(review): presumably the `missing` keyword argument is not recorded as
# missing, so the uninitialised `x[i]` values are treated as observations —
# confirm against the DynamicPPL fix. Workarounds: pass `x` positionally (as in
# `gdemo`) or supply data via `condition`.
@model function gdemo_kw(::Type{T}=Float64; x=missing) where {T}
    if x === missing
        # Initialize `x` if missing
        x = Vector{T}(undef, 2)
    end
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

# Construct a model with x = missing
model_kw = gdemo_kw(; x=missing)
# Fails with the DomainError shown below on the affected versions.
c_kw = sample(model_kw, HMC(0.01, 5), 500)
julia> c_kw = sample(model_kw, HMC(0.01, 5), 500)
ERROR: DomainError with Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(NaN,NaN,NaN):
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/packages/Distributions/ji8PW/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:36 [inlined]
  [5] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:42 [inlined]
  [6] gdemo_kw(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, arg#225::DynamicPPL.TypeWrap{…}; x::Missing)
    @ Main ./REPL[2]:7
  [7] gdemo_kw
    @ ./REPL[2]:1 [inlined]
  [8] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
  [9] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952
 [10] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887
 [11] logdensity
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 [inlined]
 [12] Fix1
    @ ./operators.jl:1118 [inlined]
 [13] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [14] vector_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::Base.Fix1{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:96
 [15] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:37 [inlined]
 [16] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:35 [inlined]
 [17] logdensity_and_gradient
    @ ~/.julia/packages/LogDensityProblemsAD/rBlLq/ext/LogDensityProblemsADForwardDiffExt.jl:118 [inlined]
 [18] ∂logπ∂θ
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:159 [inlined]
 [19] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [20] step(lf::AdvancedHMC.Leapfrog{…}, h::AdvancedHMC.Hamiltonian{…}, z::AdvancedHMC.PhasePoint{…}, n_steps::Int64; fwd::Bool, full_trajectory::Val{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:229
 [21] step
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:199 [inlined]
 [22] sample_phasepoint
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:323 [inlined]
 [23] transition
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:262 [inlined]
 [24] transition(rng::Random.TaskLocalRNG, h::AdvancedHMC.Hamiltonian{…}, κ::AdvancedHMC.HMCKernel{…}, z::AdvancedHMC.PhasePoint{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/sampler.jl:59
 [25] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…}; nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:240
 [26] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:226
 [27] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:176 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [29] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
 [30] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [31] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:93
 [32] sample
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:83 [inlined]
 [33] #sample#4
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:263 [inlined]
 [34] sample
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:256 [inlined]
 [35] #sample#3
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:253 [inlined]
 [36] sample(model::DynamicPPL.Model{…}, alg::HMC{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:247
 [37] top-level scope
    @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

  [0bf59076] AdvancedHMC v0.6.1
  [cbdf2221] AlgebraOfGraphics v0.6.19
  [c7e460c6] ArgParse v1.2.0
  [131c737c] ArviZ v0.10.5
  [4a6e88f0] ArviZPythonPlots v0.1.5
  [336ed68f] CSV v0.10.14
⌃ [13f3f980] CairoMakie v0.11.11
  [324d7699] CategoricalArrays v0.10.8
  [a93c6f00] DataFrames v1.6.1
  [1a297f60] FillArrays v1.11.0
  [663a7486] FreeTypeAbstraction v0.10.3
  [682c06a0] JSON v0.21.4
  [98e50ef6] JuliaFormatter v1.0.56
⌅ [ee78f7c6] Makie v0.20.10
  [7f7a1694] Optimization v3.25.1
  [b1d3bc72] Pathfinder v0.8.7
  [f27b6e38] Polynomials v4.0.9
  [438e738f] PyCall v1.96.4
  [37e2e3b7] ReverseDiff v1.15.3
  [295af30f] Revise v3.5.14
  [2913bbd2] StatsBase v0.34.3
  [f3b207a7] StatsPlots v0.15.7
  [fce5fe82] Turing v0.32.3
  [e88e6eb3] Zygote v0.6.70

daeh avatar Jun 06 '24 22:06 daeh

Tried again with Turing v0.33.0; the same error occurs:

julia> c_kw = sample(model_kw, HMC(0.01, 5), 500)
ERROR: DomainError with Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(NaN,NaN,NaN):
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/packages/Distributions/ji8PW/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:36 [inlined]
  [5] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:42 [inlined]
  [6] gdemo_kw(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, arg#225::DynamicPPL.TypeWrap{…}; x::Missing)
    @ Main ./REPL[5]:7
  [7] gdemo_kw
    @ ./REPL[5]:1 [inlined]
  [8] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
  [9] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952
 [10] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887
 [11] logdensity
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 [inlined]
 [12] Fix1
    @ ./operators.jl:1118 [inlined]
 [13] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [14] vector_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::Base.Fix1{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:96
 [15] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:37 [inlined]
 [16] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:35 [inlined]
 [17] logdensity_and_gradient
    @ ~/.julia/packages/LogDensityProblemsAD/rBlLq/ext/LogDensityProblemsADForwardDiffExt.jl:118 [inlined]
 [18] ∂logπ∂θ
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/hmc.jl:159 [inlined]
 [19] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [20] step(lf::AdvancedHMC.Leapfrog{…}, h::AdvancedHMC.Hamiltonian{…}, z::AdvancedHMC.PhasePoint{…}, n_steps::Int64; fwd::Bool, full_trajectory::Val{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:229
 [21] step
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:199 [inlined]
 [22] sample_phasepoint
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:323 [inlined]
 [23] transition
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:262 [inlined]
 [24] transition(rng::Random.TaskLocalRNG, h::AdvancedHMC.Hamiltonian{…}, κ::AdvancedHMC.HMCKernel{…}, z::AdvancedHMC.PhasePoint{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/sampler.jl:59
 [25] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…}; nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/iRdIB/src/mcmc/hmc.jl:240
 [26] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…})
    @ Turing.Inference ~/.julia/packages/Turing/iRdIB/src/mcmc/hmc.jl:226
 [27] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:176 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [29] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
 [30] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [31] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:93
 [32] sample
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:83 [inlined]
 [33] #sample#4
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:263 [inlined]
 [34] sample
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:256 [inlined]
 [35] #sample#3
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:253 [inlined]
 [36] sample(model::DynamicPPL.Model{…}, alg::HMC{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:247
 [37] top-level scope
    @ REPL[7]:1
Some type information was truncated. Use `show(err)` to see complete types.
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

Status `~/coding/-GitRepos/knobe-counterfactuals/Project.toml`
  [98e50ef6] JuliaFormatter v1.0.56
  [295af30f] Revise v3.5.14
  [fce5fe82] Turing v0.33.0

daeh avatar Jun 06 '24 22:06 daeh

To be honest, I'm not sure whether this is intended, but I agree that keyword arguments should be treated the same way as positional arguments.

One easy way to check what is considered random and what is considered "observed" is to sample from the model:

# Draw one sample from `model`, returned as an ordered dictionary. Variables
# that appear in the result are the ones the model treats as random; anything
# treated as observed should not appear.
rand(Turing.OrderedDict, model)

torfjelde avatar Jun 12 '24 11:06 torfjelde

By the way, these days we generally recommend using `condition` (or the `|` operator) instead of passing observations in as model arguments or keyword arguments. That is, write your model as

# Recommended pattern: keep observations out of the model's arguments and
# supply them afterwards with `condition` (or the `|` operator).
# `x` is deliberately NOT a parameter: with `gdemo(x, ...)` the call
# `gdemo()` below would throw a MethodError, and the argument would be
# overwritten by the local container anyway.
@model function gdemo(::Type{T}=Float64) where {T}
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    # Local container for the two entries; each `x[i]` is sampled unless the
    # model has been conditioned on a value for it.
    x = Vector{T}(undef, 2)
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

model = gdemo()
# `x_data` is the user's observed data (assumed to be a 2-element vector).
model_cond = model | (x = x_data,)

Going forward, this will be the recommended approach.

torfjelde avatar Jun 12 '24 11:06 torfjelde

Thanks for the quick fix! Yes, I'll start using the condition syntax from here on out. I posted the issue primarily because the kwarg behavior was very unexpected (took me a while to figure out what the issue was), and I imagined it could trip up other Turing newbies too. Thanks!

daeh avatar Jun 14 '24 00:06 daeh

Thank you for raising this issue :) I wasn't aware of this bug until you brought it up.

torfjelde avatar Jun 17 '24 20:06 torfjelde