InferenceObjects integration
The plan is to allow `AbstractMCMC.sample` to return `InferenceObjects.InferenceData` as a `chain_type` and to move toward that being a default return type in Turing. There's a mostly functional proof of concept of this integration at https://github.com/sethaxen/DynamicPPLInferenceObjects.jl. @yebai suggested moving this code into DynamicPPL directly and adding InferenceObjects as a dependency, which would increase DynamicPPL load time by 20%. I've opened this issue to discuss whether we want to take this approach or a different one for this integration.
From DynamicPPLInferenceObjects, it seems the integration may be implementable entirely by overloading methods from DynamicPPL and AbstractMCMC. Alternatively, on Julia v1.9 it could be implemented as an extension, and on earlier versions it could be loaded with Requires.
Related issues:
- https://github.com/TuringLang/MCMCChains.jl/issues/381
- https://github.com/TuringLang/Turing.jl/pull/1913
> early versions it could be loaded with Requires.
Could we just make it a proper dependency on older Julia versions? It's still possible to make it a weak dependency on Julia 1.9 at the same time.
It seems you also implement `chainstack` and `bundle_samples`? That should rather be an extension of AbstractMCMC, I assume?
> Could we just make it a proper dependency on older Julia versions? It's still possible to make it a weak dependency on Julia 1.9 at the same time.
Ah, really? Is that just making it a full dependency on 1.9 but only loading it in an extension? Or something more fancy?
> It seems you also implement `chainstack` and `bundle_samples`? That should rather be an extension of AbstractMCMC, I assume?
For `chainstack`, yes, but `bundle_samples` relies on some DynamicPPL functionality. I suppose we can restrict the model type to `DynamicPPL.Model` to avoid type piracy.
For `Chains` the corresponding methods live in Turing proper. That might be cleaner. Currently it relies on utility functions `get_params` and `get_sample_stats`, which would need to be added to the DynamicPPL API so that Turing could overload them for its samplers.
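To make the overload pattern concrete, here is a self-contained sketch using stub types in place of the real DynamicPPL/Turing types; the `get_params`/`get_sample_stats` signatures shown are hypothetical illustrations of the hook idea, not the actual DynamicPPLInferenceObjects code:

```julia
# Stub types standing in for the real Turing/DynamicPPL types (all hypothetical).
abstract type AbstractSampler end
struct NUTS <: AbstractSampler end  # stand-in for a Turing sampler

struct Transition  # stand-in for a sampler transition
    θ::NamedTuple
    stats::NamedTuple
end

# Generic fallbacks that DynamicPPL could own:
get_params(transition) = transition.θ
get_sample_stats(::AbstractSampler, transition) = NamedTuple()

# ...which Turing would then overload for its samplers:
get_sample_stats(::NUTS, transition::Transition) = transition.stats

t = Transition((mu = 0.1,), (lp = -3.2,))
get_params(t)                # (mu = 0.1,)
get_sample_stats(NUTS(), t)  # (lp = -3.2,)
```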
It's a weak dependency on Julia >= 1.9 if you declare it both as a strong dependency and a weak dependency. See https://pkgdocs.julialang.org/dev/creating-packages/#Transition-from-normal-dependency-to-extension.
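Concretely, per that guide, the trick is to declare the package in both `[deps]` and `[weakdeps]`. A `Project.toml` sketch (the UUID placeholder must be replaced with the real one):

```toml
[deps]
# Strong dependency: loads unconditionally on Julia < 1.9.
InferenceObjects = "<InferenceObjects UUID>"

[weakdeps]
# On Julia >= 1.9 this entry takes precedence, demoting it to a weak dependency.
InferenceObjects = "<InferenceObjects UUID>"

[extensions]
DynamicPPLInferenceObjectsExt = "InferenceObjects"
```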
> For Chains the corresponding methods live in Turing proper. That might be cleaner.
There have been long discussions (and even issues IIRC, maybe even in DynamicPPL?) about how messy the current situation is - e.g., in many places in DynamicPPL we rely on functionality that is only implemented in MCMCChains but avoid having it as a dependency and instead allow AbstractChains. Similarly, I think not all code in Turing should actually be there.
> For Chains the corresponding methods live in Turing proper. That might be cleaner.
>
> There have been long discussions (and even issues IIRC, maybe even in DynamicPPL?) about how messy the current situation is - e.g., in many places in DynamicPPL we rely on functionality that is only implemented in MCMCChains but avoid having it as a dependency and instead allow AbstractChains. Similarly, I think not all code in Turing should actually be there.
How would you recommend proceeding then for this integration?
Is there a particular reason why we don't just add it to Turing for now? I agree a weak dep might make sense, but it's a bit annoying to make it an explicit dependency pre-1.9, no? In Turing.jl I'm guessing the increased compilation time will be minor in comparison.
> Is there a particular reason why we don't just add it to Turing for now?
If the code lived in Turing it would entirely be type piracy. Other than that, I don't see a good reason.
A 20% increase in load time is probably not a big deal, I think.
> How would you recommend proceeding then for this integration?
I guess short-term basically anything goes - it can just take a long time to move away from and improve a suboptimal but somewhat working setup, in my experience.
In the longer term, I think a better solution would be
- to implement `chainstack` etc. in the InferenceObjects package itself by depending on AbstractMCMC or, see below, an even more lightweight chains package (similar to sampler packages)
- to generalize the methods in DynamicPPL such as `pointwise_loglikelihood`, `loglikelihood`, etc. in such a way that they can work with arbitrary `AbstractChains` as input, similar to how arrays with dimensions are supported by MCMCDiagnosticTools
- to make Turing also `AbstractChains`/`chain_type` agnostic
The last two points probably require some additions to the AbstractChains interface (well, there isn't one yet), in AbstractMCMC or some other, even more lightweight, package.
For instance, I have thought for a while that something like `eachsample`, `eachchain`, etc. (similar to `eachslice`) could be useful and be used instead of the explicit `1:size(chain, 1)` etc. in the current code in DynamicPPL.
> If the code lived in Turing it would entirely be type piracy. Other than that, I don't see a good reason.
That is basically the entirety of Turing.jl though haha :sweat_smile:
> I guess short-term basically anything goes - it can just take a long time to move away from and improve a suboptimal but somewhat working setup, in my experience.
But this is a fair point :confused: We have quite a lot of examples of that.
> I guess short-term basically anything goes - it can just take a long time to move away from and improve a suboptimal but somewhat working setup, in my experience.
Okay, I think then the approach I will take in the short term is:
- adapt all the code from DynamicPPLInferenceObjects except for the code in `bundle_samples.jl` to be an extension module `DynamicPPLInferenceObjectsExt`
- add the code in `bundle_samples.jl` to Turing
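As a structural sketch of the first step (file and module names follow the Pkg extension conventions; method bodies are elided, so this is an outline rather than working code):

```julia
# ext/DynamicPPLInferenceObjectsExt.jl
# Loaded automatically on Julia >= 1.9 when DynamicPPL and InferenceObjects
# are both in the environment (sketch only).
module DynamicPPLInferenceObjectsExt

using DynamicPPL: DynamicPPL
using InferenceObjects: InferenceObjects

# Everything from DynamicPPLInferenceObjects except bundle_samples goes here:
# overloads of DynamicPPL/AbstractMCMC methods for InferenceObjects types.

end
```

The `bundle_samples` methods, restricted to `model::DynamicPPL.Model` to avoid type piracy, would then live in Turing itself.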
> In the longer term, I think a better solution would be
These all sound good, but at the moment I lack the bandwidth to tackle them.
> For instance, I have thought for a while that something like `eachsample`, `eachchain`, etc. (similar to `eachslice`) could be useful and be used instead of the explicit `1:size(chain, 1)` etc. in the current code in DynamicPPL.
Agreed! Actually, this would be automatically supported in InferenceObjects once DimensionalData adds eachslice support for AbstractDimStack (see https://github.com/rafaqz/DimensionalData.jl/pull/418), but in the meantime this works:
```julia
julia> using InferenceObjects, DimensionalData

julia> function _eachslice(ds::InferenceObjects.Dataset; dims)
           concrete_dims = DimensionalData.dims(ds, dims)
           return (view(ds, d...) for d in DimensionalData.DimIndices(concrete_dims))
       end;

julia> eachchain(ds::InferenceObjects.Dataset) = _eachslice(ds; dims=DimensionalData.Dim{:chain});

julia> function eachsample(ds::InferenceObjects.Dataset)
           sample_dims = (DimensionalData.Dim{:chain}, DimensionalData.Dim{:draw})
           return _eachslice(ds; dims=sample_dims)
       end;

julia> using ArviZExampleData;

julia> ds = load_example_data("centered_eight").posterior
Dataset with dimensions:
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> collect(eachchain(ds))[1]
Dataset with dimensions:
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw} (500)
  :theta Float64 dims: Dim{:school}, Dim{:draw} (8×500)
  :tau   Float64 dims: Dim{:draw} (500)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> collect(eachsample(ds))[1]
Dataset with dimensions:
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims:
  :theta Float64 dims: Dim{:school} (8)
  :tau   Float64 dims:

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"
```
Edit: adapted for https://github.com/rafaqz/DimensionalData.jl/pull/462