
Give user an option to use `SimpleVarInfo` with `sample` function

sunxd3 opened this issue • 13 comments

Partially addresses https://github.com/TuringLang/Turing.jl/issues/2213

An example:

julia> using AbstractMCMC, DynamicPPL
[...]

julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#289" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DefaultContext())

julia> chn = sample(model, SampleFromUniform(), 10; trace_type = SimpleVarInfo)
10-element Vector{SimpleVarInfo{OrderedDict{Any, Any}, Float64, DynamicPPL.NoTransformation}}:
 SimpleVarInfo(OrderedDict{Any, Any}(s[1] => 0.18821524284155777, s[2] => 0.33957731437677985, m[1] => 0.027047762098432387, m[2] => -0.3396883816169604), -27.049779424102084)
 ...
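Under the hood, the Dict-backed `SimpleVarInfo` is filled during model evaluation. Roughly (a sketch using DynamicPPL's public `evaluate!!` and `SamplingContext` API; this is not part of this PR's changes):

using Random, OrderedCollections

# Evaluating the model under a sampling context inserts a varname into the
# dict for every `x ~ dist` statement encountered, so no variable names need
# to be known up front.
_, svi = DynamicPPL.evaluate!!(
    model,
    SimpleVarInfo(OrderedDict{Any,Any}()),
    SamplingContext(Random.default_rng(), SampleFromUniform()),
)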

Work with Turing

This should work with Turing's inference pipeline with almost no modification; the only change needed is at https://github.com/TuringLang/Turing.jl/blob/56f64ec5909cec4a5ded4e28555c2b289020bbe1/src/mcmc/Inference.jl#L319, which becomes

function getparams(model::DynamicPPL.Model, vi::Union{DynamicPPL.VarInfo, DynamicPPL.SimpleVarInfo})

This allows `bundle_samples` to use this function for `SimpleVarInfo` traces as well.
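Spelled out (only the signature is widened; the body is Turing's existing implementation and stays untouched):

function getparams(
    model::DynamicPPL.Model,
    vi::Union{DynamicPPL.VarInfo,DynamicPPL.SimpleVarInfo},
)
    # ... existing body unchanged ...
end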

Then

julia> AbstractMCMC.step(Random.default_rng(), model, DynamicPPL.Sampler(HMC(0.2, 20), DynamicPPL.Selector()); trace_type = SimpleVarInfo)
(Turing.Inference.Transition{Vector{Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}}, Float64, @NamedTuple{n_steps::Int64, is_accept::Bool, acceptance_rate::Float64, log_density::Float64, hamiltonian_energy::Float64, hamiltonian_energy_error::Float64, numerical_error::Bool, step_size::Float64, nom_step_size::Float64}}(Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}[(s[1], 1.6822421472438154), (s[2], 0.8921514354736135), (m[1], -0.1272569385613846), (m[2], 0.8103126419880976)], -7.598386060870171, (n_steps = 20, is_accept = true, acceptance_rate = 1.0, log_density = -7.598386060870171, hamiltonian_energy = 9.707595087582115, hamiltonian_energy_error = -0.0094109681431096, numerical_error = true, step_size = 0.2, nom_step_size = 0.2)), Turing.Inference.HMCState{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}, AdvancedHMC.Hamiltonian{AdvancedHMC.UnitEuclideanMetric{Float64, Tuple{Int64}}, AdvancedHMC.GaussianKinetic, Base.Fix1{typeof(LogDensityProblems.logdensity), LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}, Turing.Inference.var"#∂logπ∂θ#32"{LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.NoAdaptation}(DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}(OrderedCollections.OrderedDict{Any, Float64}(s[1] => 0.5201275150675577, s[2] => -0.11411939010121486, m[1] => -0.1272569385613846, m[2] => 0.8103126419880976), -7.598386060870171, DynamicPPL.DynamicTransformation()), 1, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, 
AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.EndPointTS}(integrator=Leapfrog(ϵ=0.2), tc=AdvancedHMC.FixedNSteps(20))), Hamiltonian(metric=UnitEuclideanMetric([1.0, 1.0, 1.0, 1.0]), kinetic=AdvancedHMC.GaussianKinetic()), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.5201275150675577, -0.11411939010121486, -0.1272569385613846, 0.8103126419880976], [0.2195756027868936, 2.0379584520488434, 0.11023899654296329, 0.06911815570849916], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-7.598386060870171, [0.4248179767985423, -1.5238746846234323, -1.042961549856163, -0.4252357850238818]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-2.109209026711945, [-0.2195756027868936, -2.0379584520488434, -0.11023899654296329, -0.06911815570849916])), AdvancedHMC.Adaptation.NoAdaptation()))

julia> chn = sample(model, HMC(0.2, 20), 10; trace_type = SimpleVarInfo)
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 2.04 seconds
Compute duration  = 2.04 seconds
parameters        = s[1], s[2], m[1], m[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

        s[1]    1.6551    0.9727    0.3219     8.1555    10.0000    1.6048        3.9939
        s[2]    0.9875    0.1954    0.0942     4.5670    10.0000    1.5820        2.2365
        m[1]    0.7053    1.0723    0.4889     5.0832    10.0000    1.4156        2.4893
        m[2]    1.1302    0.5511    0.1743    10.0000    10.0000    0.9406        4.8972

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        s[1]    0.6404    0.9317    1.5316    2.0851    3.5047
        s[2]    0.7580    0.8095    1.0099    1.1126    1.3119
        m[1]   -0.9806   -0.1058    1.0414    1.3320    2.2166
        m[2]    0.3819    0.7973    1.0316    1.6741    1.8228

sunxd3 · May 15 '24

An alternative approach I can envision is fully adopting LogDensityFunction in Turing through the ExternalSampler interface, but this might require much more serious work encapsulating the InferenceAlgorithms. (Do we have plans to do this in the coming months?)
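For reference, the external-sampler route currently looks roughly like the sketch below (illustrative only; whether a SimpleVarInfo-backed trace would be carried through this path is exactly the open question):

using Turing, AdvancedHMC

# Wrap a sampler from AdvancedHMC.jl; Turing then drives it through the
# AbstractMCMC interface with the model exposed as a LogDensityFunction,
# bypassing most of the InferenceAlgorithm machinery.
chn = sample(model, externalsampler(AdvancedHMC.NUTS(0.8)), 10)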

Also, for SimpleVarInfo I opted for OrderedDict. Using NamedTuple instead requires predetermining the variable names (ref https://github.com/TuringLang/DynamicPPL.jl/blob/48487cc471543030699e3a39776dd939756d0165/src/simple_varinfo.jl#L62-L65); in principle this can be done, but it needs a bit of work. This also means that SimpleVarInfo, as used through the changes proposed in this PR, may be less performant.
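To illustrate the trade-off (a minimal sketch using DynamicPPL's documented constructors; the names `s` and `m` are placeholders):

using DynamicPPL, OrderedCollections

# Dict-backed: variable names are discovered during model evaluation, at the
# cost of abstractly typed (Any) storage.
svi_dict = SimpleVarInfo(OrderedDict{Any,Any}())

# NamedTuple-backed: concrete, type-stable storage, but every variable (with
# a correctly shaped value) must be supplied before evaluation.
svi_nt = SimpleVarInfo((s = [1.0, 1.0], m = [0.0, 0.0]))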

sunxd3 · May 15 '24

@torfjelde @yebai @devmotion does this PR make sense? If this is desirable, I'll extend the tests in https://github.com/TuringLang/DynamicPPL.jl/blob/master/test/sampler.jl.

sunxd3 · May 15 '24

Pull Request Test Coverage Report for Build 9092154133

Details

  • 9 of 17 (52.94%) changed or added relevant lines in 2 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-1.0%) to 77.471%

Changes missing coverage:

  File                    Covered Lines   Changed/Added Lines        %
  src/simple_varinfo.jl               0                     2     0.0%
  src/sampler.jl                      9                    15    60.0%
  Total                               9                    17

Files with coverage reduction:

  File             New Missed Lines        %
  src/sampler.jl                  2   84.75%

Totals coverage status:

  Change from base Build 9062140401: -1.0%
  Covered Lines: 2775
  Relevant Lines: 3582

💛 - Coveralls

coveralls · May 15 '24

Pull Request Test Coverage Report for Build 9116442883

Details

  • 17 of 20 (85.0%) changed or added relevant lines in 2 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.03%) to 77.596%

Changes missing coverage:

  File                    Covered Lines   Changed/Added Lines        %
  src/simple_varinfo.jl               0                     1     0.0%
  src/sampler.jl                     17                    19    89.47%
  Total                              17                    20

Totals coverage status:

  Change from base Build 9099752668: +0.03%
  Covered Lines: 2660
  Relevant Lines: 3428

💛 - Coveralls

coveralls · May 15 '24

cc @willtebbutt

yebai · May 15 '24

@willtebbutt can you help review this PR?

yebai · May 16 '24

Having a look @sunxd3 :+1:

torfjelde · May 16 '24

@torfjelde is my understanding of SimpleVarInfo with NamedTuple correct at https://github.com/TuringLang/DynamicPPL.jl/pull/606#issuecomment-2111867290?

sunxd3 · May 16 '24

> we need to make it possible to use the NamedTuple version

I agree.

Other than performance, I thought SimpleVarInfo was also less error-prone for AD (correct me if I'm wrong), but I am unsure whether the AbstractDict version of SimpleVarInfo works better than VarInfo.

sunxd3 · May 16 '24

> Is my understanding of SimpleVarInfo with NamedTuple correct at https://github.com/TuringLang/DynamicPPL.jl/pull/606#issuecomment-2111867290?

Yep! If you don't "seed" SimpleVarInfo{<:NamedTuple} with the correct values, then it will only be sensible for models containing only varnames of the form VarName{sym,typeof(identity)}. Now that we have "debugging" capabilities, we could "check" this, but that would of course not be 100% reliable.
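To make the distinction concrete (hypothetical models, sketched from the description above):

using DynamicPPL, Distributions

@model function plain()
    m ~ Normal()        # VarName{:m, typeof(identity)}; fine unseeded
end

@model function indexed()
    x = Vector{Float64}(undef, 2)
    for i in 1:2
        x[i] ~ Normal() # VarName with an index optic; the NamedTuple variant
    end                 # must be seeded with a 2-element `x`
end

# "Seeding" means supplying a correctly shaped container up front:
svi = SimpleVarInfo((x = zeros(2),))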

> Other than performance, I thought SimpleVarInfo was also less error-prone for AD (correct me if I'm wrong)

Not really. Or rather, I don't think VarInfo is particularly error-prone for AD either. SimpleVarInfo{<:NamedTuple} might be for some of the more recent AD backends, e.g. Tapir.jl and Enzyme.jl, but only because it improves type stability (which is not the case for SimpleVarInfo{<:Dict}).
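A rough way to see the type-stability point (an assumed setup, not taken from the PR):

using DynamicPPL

svi_nt   = SimpleVarInfo((s = 1.0, m = 0.0))
svi_dict = SimpleVarInfo(Dict{Any,Any}())

# typeof(svi_nt) carries concrete value types,
# NamedTuple{(:s, :m), Tuple{Float64, Float64}}, so lookups infer and AD
# backends see type-stable code; typeof(svi_dict) stores values as Any,
# so lookups cannot be inferred.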

torfjelde · May 16 '24

Let me look into it and see if I can make the NamedTuple variant of SimpleVarInfo work, or at least leave a clear set of TODOs.

sunxd3 · May 17 '24

> Let me look into it and see if I can make the NamedTuple variant of SimpleVarInfo work, or at least leave a clear set of TODOs.

Lovely :) And I don't mean to be negative about this, btw. I'm just not a huge fan of adding additional kwargs, etc. unless there's a clear reason to use them (because otherwise nobody will ever use them, until they are suddenly riddled with uncaught bugs precisely because nobody uses them). So if we're going to add an additional kwarg to use SimpleVarInfo, we should simultaneously prove that it has utility, i.e.:

  1. Make it work with SimpleVarInfo{<:NamedTuple}, because that definitely has utility.
  2. Test it properly in Turing.jl and make sure that SimpleVarInfo is indeed used, and that we never implicitly convert to VarInfo, before we merge this into DynamicPPL.jl.

torfjelde · May 17 '24

Sounds good. I started this as a quick-and-dirty prototype; it definitely needs more work before we can justify complicating the interface.

sunxd3 · May 17 '24

This is no longer necessary since Tapir now works with all tracing types.

yebai · Aug 06 '24