DynamicPPL.jl
DynamicPPL.jl copied to clipboard
Give users the option to use `SimpleVarInfo` with the `sample` function
Partially addresses https://github.com/TuringLang/Turing.jl/issues/2213
An example
julia> using AbstractMCMC, DynamicPPL
[...]
julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#289" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DefaultContext())
julia> chn = sample(model, SampleFromUniform(), 10; trace_type = SimpleVarInfo)
10-element Vector{SimpleVarInfo{OrderedDict{Any, Any}, Float64, DynamicPPL.NoTransformation}}:
SimpleVarInfo(OrderedDict{Any, Any}(s[1] => 0.18821524284155777, s[2] => 0.33957731437677985, m[1] => 0.027047762098432387, m[2] => -0.3396883816169604), -27.049779424102084)
...
Work with Turing
This should work with Turing's `Inference` pipeline with almost no modification.
The only change needed is to change the method signature at https://github.com/TuringLang/Turing.jl/blob/56f64ec5909cec4a5ded4e28555c2b289020bbe1/src/mcmc/Inference.jl#L319 to:
function getparams(model::DynamicPPL.Model, vi::Union{DynamicPPL.VarInfo, DynamicPPL.SimpleVarInfo})
This allows `bundle_samples` to use this function.
Then, with that change applied:
julia> AbstractMCMC.step(Random.default_rng(), model, DynamicPPL.Sampler(HMC(0.2, 20), DynamicPPL.Selector()); trace_type = SimpleVarInfo)
(Turing.Inference.Transition{Vector{Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}}, Float64, @NamedTuple{n_steps::Int64, is_accept::Bool, acceptance_rate::Float64, log_density::Float64, hamiltonian_energy::Float64, hamiltonian_energy_error::Float64, numerical_error::Bool, step_size::Float64, nom_step_size::Float64}}(Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}[(s[1], 1.6822421472438154), (s[2], 0.8921514354736135), (m[1], -0.1272569385613846), (m[2], 0.8103126419880976)], -7.598386060870171, (n_steps = 20, is_accept = true, acceptance_rate = 1.0, log_density = -7.598386060870171, hamiltonian_energy = 9.707595087582115, hamiltonian_energy_error = -0.0094109681431096, numerical_error = true, step_size = 0.2, nom_step_size = 0.2)), Turing.Inference.HMCState{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}, AdvancedHMC.Hamiltonian{AdvancedHMC.UnitEuclideanMetric{Float64, Tuple{Int64}}, AdvancedHMC.GaussianKinetic, Base.Fix1{typeof(LogDensityProblems.logdensity), LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, 
Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}, Turing.Inference.var"#∂logπ∂θ#32"{LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.NoAdaptation}(DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}(OrderedCollections.OrderedDict{Any, Float64}(s[1] => 0.5201275150675577, s[2] => -0.11411939010121486, m[1] => -0.1272569385613846, m[2] => 0.8103126419880976), -7.598386060870171, DynamicPPL.DynamicTransformation()), 1, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.EndPointTS}(integrator=Leapfrog(ϵ=0.2), tc=AdvancedHMC.FixedNSteps(20))), Hamiltonian(metric=UnitEuclideanMetric([1.0, 1.0, 1.0, 1.0]), kinetic=AdvancedHMC.GaussianKinetic()), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.5201275150675577, -0.11411939010121486, -0.1272569385613846, 0.8103126419880976], 
[0.2195756027868936, 2.0379584520488434, 0.11023899654296329, 0.06911815570849916], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-7.598386060870171, [0.4248179767985423, -1.5238746846234323, -1.042961549856163, -0.4252357850238818]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-2.109209026711945, [-0.2195756027868936, -2.0379584520488434, -0.11023899654296329, -0.06911815570849916])), AdvancedHMC.Adaptation.NoAdaptation()))
julia> chn = sample(model, HMC(0.2, 20), 10; trace_type = SimpleVarInfo)
Chains MCMC chain (10×14×1 Array{Float64, 3}):
Iterations = 1:1:10
Number of chains = 1
Samples per chain = 10
Wall duration = 2.04 seconds
Compute duration = 2.04 seconds
parameters = s[1], s[2], m[1], m[2]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
s[1] 1.6551 0.9727 0.3219 8.1555 10.0000 1.6048 3.9939
s[2] 0.9875 0.1954 0.0942 4.5670 10.0000 1.5820 2.2365
m[1] 0.7053 1.0723 0.4889 5.0832 10.0000 1.4156 2.4893
m[2] 1.1302 0.5511 0.1743 10.0000 10.0000 0.9406 4.8972
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
s[1] 0.6404 0.9317 1.5316 2.0851 3.5047
s[2] 0.7580 0.8095 1.0099 1.1126 1.3119
m[1] -0.9806 -0.1058 1.0414 1.3320 2.2166
m[2] 0.3819 0.7973 1.0316 1.6741 1.8228