DynamicPPL.jl icon indicating copy to clipboard operation
DynamicPPL.jl copied to clipboard

Give users an option to use `SimpleVarInfo` with the `sample` function

Open sunxd3 opened this issue 9 months ago • 13 comments

Partially addresses https://github.com/TuringLang/Turing.jl/issues/2213

An example

julia> using AbstractMCMC, DynamicPPL
[...]

julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#289" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DefaultContext())

julia> chn = sample(model, SampleFromUniform(), 10; trace_type = SimpleVarInfo)
10-element Vector{SimpleVarInfo{OrderedDict{Any, Any}, Float64, DynamicPPL.NoTransformation}}:
 SimpleVarInfo(OrderedDict{Any, Any}(s[1] => 0.18821524284155777, s[2] => 0.33957731437677985, m[1] => 0.027047762098432387, m[2] => -0.3396883816169604), -27.049779424102084)
 ...

Work with Turing

This should work with Turing's inference pipeline with almost no modification; the only change needed is to widen the method signature at https://github.com/TuringLang/Turing.jl/blob/56f64ec5909cec4a5ded4e28555c2b289020bbe1/src/mcmc/Inference.jl#L319 to

function getparams(model::DynamicPPL.Model, vi::Union{DynamicPPL.VarInfo, DynamicPPL.SimpleVarInfo})

This allows `bundle_samples` to call `getparams` when the sampled state is a `SimpleVarInfo`.

Then

julia> AbstractMCMC.step(Random.default_rng(), model, DynamicPPL.Sampler(HMC(0.2, 20), DynamicPPL.Selector()); trace_type = SimpleVarInfo)
(Turing.Inference.Transition{Vector{Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}}, Float64, @NamedTuple{n_steps::Int64, is_accept::Bool, acceptance_rate::Float64, log_density::Float64, hamiltonian_energy::Float64, hamiltonian_energy_error::Float64, numerical_error::Bool, step_size::Float64, nom_step_size::Float64}}(Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}[(s[1], 1.6822421472438154), (s[2], 0.8921514354736135), (m[1], -0.1272569385613846), (m[2], 0.8103126419880976)], -7.598386060870171, (n_steps = 20, is_accept = true, acceptance_rate = 1.0, log_density = -7.598386060870171, hamiltonian_energy = 9.707595087582115, hamiltonian_energy_error = -0.0094109681431096, numerical_error = true, step_size = 0.2, nom_step_size = 0.2)), Turing.Inference.HMCState{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}, AdvancedHMC.Hamiltonian{AdvancedHMC.UnitEuclideanMetric{Float64, Tuple{Int64}}, AdvancedHMC.GaussianKinetic, Base.Fix1{typeof(LogDensityProblems.logdensity), LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, 
Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}, Turing.Inference.var"#∂logπ∂θ#32"{LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.NoAdaptation}(DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}(OrderedCollections.OrderedDict{Any, Float64}(s[1] => 0.5201275150675577, s[2] => -0.11411939010121486, m[1] => -0.1272569385613846, m[2] => 0.8103126419880976), -7.598386060870171, DynamicPPL.DynamicTransformation()), 1, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.EndPointTS}(integrator=Leapfrog(ϵ=0.2), tc=AdvancedHMC.FixedNSteps(20))), Hamiltonian(metric=UnitEuclideanMetric([1.0, 1.0, 1.0, 1.0]), kinetic=AdvancedHMC.GaussianKinetic()), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.5201275150675577, -0.11411939010121486, -0.1272569385613846, 0.8103126419880976], 
[0.2195756027868936, 2.0379584520488434, 0.11023899654296329, 0.06911815570849916], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-7.598386060870171, [0.4248179767985423, -1.5238746846234323, -1.042961549856163, -0.4252357850238818]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-2.109209026711945, [-0.2195756027868936, -2.0379584520488434, -0.11023899654296329, -0.06911815570849916])), AdvancedHMC.Adaptation.NoAdaptation()))

julia> chn = sample(model, HMC(0.2, 20), 10; trace_type = SimpleVarInfo)
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 2.04 seconds
Compute duration  = 2.04 seconds
parameters        = s[1], s[2], m[1], m[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

        s[1]    1.6551    0.9727    0.3219     8.1555    10.0000    1.6048        3.9939
        s[2]    0.9875    0.1954    0.0942     4.5670    10.0000    1.5820        2.2365
        m[1]    0.7053    1.0723    0.4889     5.0832    10.0000    1.4156        2.4893
        m[2]    1.1302    0.5511    0.1743    10.0000    10.0000    0.9406        4.8972

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        s[1]    0.6404    0.9317    1.5316    2.0851    3.5047
        s[2]    0.7580    0.8095    1.0099    1.1126    1.3119
        m[1]   -0.9806   -0.1058    1.0414    1.3320    2.2166
        m[2]    0.3819    0.7973    1.0316    1.6741    1.8228

sunxd3 avatar May 15 '24 08:05 sunxd3