DynamicPPL.jl
Give user an option to use `SimpleVarInfo` with `sample` function
Partially addresses https://github.com/TuringLang/Turing.jl/issues/2213.
An example:

```julia
julia> using AbstractMCMC, DynamicPPL
[...]

julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#289" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DefaultContext())

julia> chn = sample(model, SampleFromUniform(), 10; trace_type = SimpleVarInfo)
10-element Vector{SimpleVarInfo{OrderedDict{Any, Any}, Float64, DynamicPPL.NoTransformation}}:
 SimpleVarInfo(OrderedDict{Any, Any}(s[1] => 0.18821524284155777, s[2] => 0.33957731437677985, m[1] => 0.027047762098432387, m[2] => -0.3396883816169604), -27.049779424102084)
 ...
```
Work with Turing

This should work with Turing's inference pipeline with almost no modification; the only change needed is https://github.com/TuringLang/Turing.jl/blob/56f64ec5909cec4a5ded4e28555c2b289020bbe1/src/mcmc/Inference.jl#L319, which becomes

```julia
function getparams(model::DynamicPPL.Model, vi::Union{DynamicPPL.VarInfo, DynamicPPL.SimpleVarInfo})
```

This allows `bundle_samples` to use this function.
Then:

```julia
julia> AbstractMCMC.step(Random.default_rng(), model, DynamicPPL.Sampler(HMC(0.2, 20), DynamicPPL.Selector()); trace_type = SimpleVarInfo)
(Turing.Inference.Transition{Vector{Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}}, Float64, @NamedTuple{n_steps::Int64, is_accept::Bool, acceptance_rate::Float64, log_density::Float64, hamiltonian_energy::Float64, hamiltonian_energy_error::Float64, numerical_error::Bool, step_size::Float64, nom_step_size::Float64}}(Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}[(s[1], 1.6822421472438154), (s[2], 0.8921514354736135), (m[1], -0.1272569385613846), (m[2], 0.8103126419880976)], -7.598386060870171, (n_steps = 20, is_accept = true, acceptance_rate = 1.0, log_density = -7.598386060870171, hamiltonian_energy = 9.707595087582115, hamiltonian_energy_error = -0.0094109681431096, numerical_error = true, step_size = 0.2, nom_step_size = 0.2)), Turing.Inference.HMCState{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}, AdvancedHMC.Hamiltonian{AdvancedHMC.UnitEuclideanMetric{Float64, Tuple{Int64}}, AdvancedHMC.GaussianKinetic, Base.Fix1{typeof(LogDensityProblems.logdensity), LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}, Turing.Inference.var"#∂logπ∂θ#32"{LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.NoAdaptation}(DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}(OrderedCollections.OrderedDict{Any, Float64}(s[1] => 0.5201275150675577, s[2] => -0.11411939010121486, m[1] => -0.1272569385613846, m[2] => 0.8103126419880976), -7.598386060870171, DynamicPPL.DynamicTransformation()), 1, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, 
AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.EndPointTS}(integrator=Leapfrog(ϵ=0.2), tc=AdvancedHMC.FixedNSteps(20))), Hamiltonian(metric=UnitEuclideanMetric([1.0, 1.0, 1.0, 1.0]), kinetic=AdvancedHMC.GaussianKinetic()), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.5201275150675577, -0.11411939010121486, -0.1272569385613846, 0.8103126419880976], [0.2195756027868936, 2.0379584520488434, 0.11023899654296329, 0.06911815570849916], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-7.598386060870171, [0.4248179767985423, -1.5238746846234323, -1.042961549856163, -0.4252357850238818]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-2.109209026711945, [-0.2195756027868936, -2.0379584520488434, -0.11023899654296329, -0.06911815570849916])), AdvancedHMC.Adaptation.NoAdaptation()))
```

```julia
julia> chn = sample(model, HMC(0.2, 20), 10; trace_type = SimpleVarInfo)
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 2.04 seconds
Compute duration  = 2.04 seconds
parameters        = s[1], s[2], m[1], m[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

        s[1]    1.6551    0.9727    0.3219     8.1555    10.0000    1.6048        3.9939
        s[2]    0.9875    0.1954    0.0942     4.5670    10.0000    1.5820        2.2365
        m[1]    0.7053    1.0723    0.4889     5.0832    10.0000    1.4156        2.4893
        m[2]    1.1302    0.5511    0.1743    10.0000    10.0000    0.9406        4.8972

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        s[1]    0.6404    0.9317    1.5316    2.0851    3.5047
        s[2]    0.7580    0.8095    1.0099    1.1126    1.3119
        m[1]   -0.9806   -0.1058    1.0414    1.3320    2.2166
        m[2]    0.3819    0.7973    1.0316    1.6741    1.8228
```
An alternative approach I can envision is fully adopting `LogDensityFunction` in Turing through the `ExternalSampler` interface, but this might require much more serious work encapsulating the `InferenceAlgorithm`s. (Do we have plans to do this in the coming months?)
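For concreteness, a rough sketch of what that alternative could look like (this is not part of this PR, just the `LogDensityProblems`-style usage I have in mind; `DynamicPPL.LogDensityFunction` already implements that interface):

```julia
# Sketch only, not part of this PR: a sampler that only speaks
# LogDensityProblems could be driven without ever touching a VarInfo directly.
using DynamicPPL, LogDensityProblems

ldf = DynamicPPL.LogDensityFunction(model)    # `model` as in the example above
d   = LogDensityProblems.dimension(ldf)       # number of parameters
LogDensityProblems.logdensity(ldf, randn(d))  # log density at a random point
```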
Also, for `SimpleVarInfo` I opted for `OrderedDict`. Using a `NamedTuple` requires predetermining the variable names (ref https://github.com/TuringLang/DynamicPPL.jl/blob/48487cc471543030699e3a39776dd939756d0165/src/simple_varinfo.jl#L62-L65); in principle this can be done, but it needs a bit of work. This also means that `SimpleVarInfo`, when used through the changes proposed in this PR, may be less performant.
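To make the trade-off concrete, a minimal sketch of the two backings (the seeded values below are hypothetical):

```julia
# Minimal sketch: the OrderedDict backing can grow as the model is evaluated,
# while the NamedTuple backing needs every variable name (and shape) up front.
using DynamicPPL, OrderedCollections

svi_dict = SimpleVarInfo(OrderedDict{Any,Any}())            # names discovered at run time
svi_nt   = SimpleVarInfo((s = [1.0, 1.0], m = [0.0, 0.0]))  # names fixed in advance
```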
@torfjelde @yebai @devmotion does this PR make sense? If it is desirable, I'll extend the tests in https://github.com/TuringLang/DynamicPPL.jl/blob/master/test/sampler.jl.
Pull Request Test Coverage Report for Build 9092154133

Details
- 9 of 17 (52.94%) changed or added relevant lines in 2 files are covered.
- 2 unchanged lines in 1 file lost coverage.
- Overall coverage decreased (-1.0%) to 77.471%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/simple_varinfo.jl | 0 | 2 | 0.0% |
| src/sampler.jl | 9 | 15 | 60.0% |
| Total: | 9 | 17 | 52.94% |

| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| src/sampler.jl | 2 | 84.75% |
| Total: | 2 | |

| Totals | |
|---|---|
| Change from base Build 9062140401: | -1.0% |
| Covered Lines: | 2775 |
| Relevant Lines: | 3582 |
💛 - Coveralls
Pull Request Test Coverage Report for Build 9116442883

Details
- 17 of 20 (85.0%) changed or added relevant lines in 2 files are covered.
- No unchanged relevant lines lost coverage.
- Overall coverage increased (+0.03%) to 77.596%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/simple_varinfo.jl | 0 | 1 | 0.0% |
| src/sampler.jl | 17 | 19 | 89.47% |
| Total: | 17 | 20 | 85.0% |

| Totals | |
|---|---|
| Change from base Build 9099752668: | 0.03% |
| Covered Lines: | 2660 |
| Relevant Lines: | 3428 |
💛 - Coveralls
cc @willtebbutt
@willtebbutt can you help review this PR?
Having a look @sunxd3 :+1:
@torfjelde is my understanding of `SimpleVarInfo` with `NamedTuple` correct at https://github.com/TuringLang/DynamicPPL.jl/pull/606#issuecomment-2111867290?
> we need to make it possible to use the `NamedTuple` version
I agree.
Other than performance, I thought `SimpleVarInfo` was also less error-prone for AD (correct me if I'm wrong), but I am unsure whether the `AbstractDict` version of `SimpleVarInfo` works better than `VarInfo`.
> is my understanding of `SimpleVarInfo` with `NamedTuple` correct at https://github.com/TuringLang/DynamicPPL.jl/pull/606#issuecomment-2111867290?
Yep! If you don't "seed" `SimpleVarInfo{<:NamedTuple}` with the correct values, then it will only be sensible for models containing only varnames of the form `VarName{sym,typeof(identity)}`. Now that we have "debugging" capabilities, we could "check" this, but that would ofc not be 100% reliable.
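A minimal sketch of that point, with hypothetical models:

```julia
# Hypothetical models illustrating the "seeding" requirement described above.
using DynamicPPL, Distributions

@model function demo_plain()
    m ~ Normal()      # varname `m` has the form VarName{:m, typeof(identity)}
end

@model function demo_indexed()
    m = Vector{Float64}(undef, 2)
    m[1] ~ Normal()   # varname `m[1]` carries an index lens
    m[2] ~ Normal()
end

# An unseeded SimpleVarInfo{<:NamedTuple} can only make sense for models like
# `demo_plain`; for `demo_indexed` it has to be seeded with a container of the
# right shape, so that `m[1]` and `m[2]` have somewhere to go:
svi_seeded = SimpleVarInfo((m = zeros(2),))
```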
> Other than performance, I thought `SimpleVarInfo` was also less error-prone for AD (correct me if I'm wrong)
Not really. Or rather, I don't think `VarInfo` is particularly error-prone for AD either. `SimpleVarInfo{<:NamedTuple}` might help with some of the more recent AD backends, e.g. Tapir.jl and Enzyme.jl, but only because it improves type stability (which is not the case for `SimpleVarInfo{<:Dict}`).
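To illustrate the type-stability point (a sketch; the comments are indicative, not exact REPL output):

```julia
# The storage type becomes a type parameter of SimpleVarInfo, so a NamedTuple
# backing gives the compiler concrete value types, while an OrderedDict{Any,Any}
# backing forces lookups to return Any.
using DynamicPPL, OrderedCollections

svi_nt   = SimpleVarInfo((m = 0.0,))              # SimpleVarInfo{<:NamedTuple}: concrete types
svi_dict = SimpleVarInfo(OrderedDict{Any,Any}())  # values typed Any => type-unstable
```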
Let me look into it and see if I can make the `NamedTuple` variant of `SimpleVarInfo` work, or at least leave clear TODOs.
> Let me look into it and see if I can make the `NamedTuple` variant of `SimpleVarInfo` work, or at least leave clear TODOs.
Lovely :) And I don't mean to be negative about this btw. I'm just not a huge fan of adding additional kwargs, etc. unless there's a clear reason to use them (because otherwise nobody will ever use them, and then they are suddenly riddled with uncaught bugs because nobody uses them). So if we're going to add an additional kwarg to use `SimpleVarInfo`, we should simultaneously prove that it has utility, i.e.:
- Make it work with `SimpleVarInfo{<:NamedTuple}`, because that definitively has utility.
- Test it properly in Turing.jl and make sure that `SimpleVarInfo` is indeed used, and that we nowhere implicitly convert to `VarInfo`, before we merge this into DynamicPPL.jl.
Sounds good. I started this as a quick and dirty prototype; it definitely needs more work before we can justify complicating the interface.
This is no longer necessary since Tapir now works with all tracing types.