
InferenceObjects integration

Open • sethaxen opened this issue 2 years ago • 12 comments

The plan is to allow AbstractMCMC.sample to return InferenceObjects.InferenceData as a chain_type and to move toward that being a default return type in Turing. There's a mostly functional proof of concept of this integration at https://github.com/sethaxen/DynamicPPLInferenceObjects.jl. @yebai suggested moving this code into DynamicPPL directly and adding InferenceObjects as a dependency, which would increase DynamicPPL load time by 20%. I've opened this issue to discuss whether we want to take this approach or a different one for this integration.

From DynamicPPLInferenceObjects, it seems the integration may be entirely implementable just by overloading methods from DynamicPPL and AbstractMCMC, so alternatively, on Julia v1.9 it could be implemented as an extension, and on earlier versions it could be loaded with Requires.

Related issues:

  • https://github.com/TuringLang/MCMCChains.jl/issues/381
  • https://github.com/TuringLang/Turing.jl/pull/1913

sethaxen · Feb 17 '23 11:02
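
For context, the user-facing API proposed here would look roughly like the following sketch (the model is illustrative and not from the proof of concept; only the chain_type keyword itself is existing AbstractMCMC/Turing API):

using Turing, InferenceObjects

# Illustrative model, not from the thread:
@model function demo(x)
    μ ~ Normal(0, 1)
    x ~ Normal(μ, 1)
end

# Passing InferenceData as chain_type is the integration this issue proposes:
idata = sample(demo(1.0), NUTS(), 1_000; chain_type=InferenceObjects.InferenceData)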

earlier versions it could be loaded with Requires.

Could we just make it a proper dependency on older Julia versions? It's still possible to make it a weak dependency on Julia 1.9 at the same time.

devmotion · Feb 17 '23 12:02

It seems you also implement chainstack and bundle_samples? That should rather be an extension of AbstractMCMC, I assume?

devmotion · Feb 17 '23 12:02

Could we just make it a proper dependency on older Julia versions? It's still possible to make it a weak dependency on Julia 1.9 at the same time.

Ah, really? Is that just making it a full dependency on 1.9 but only loading it in an extension? Or something more fancy?

It seems you also implement chainstack and bundle_samples? That should rather be an extension of AbstractMCMC, I assume?

For chainstack, yes, but bundle_samples relies on some DynamicPPL functionality. I suppose we can restrict the model type to DynamicPPL.Model to avoid type piracy.

For Chains the corresponding methods live in Turing proper. That might be cleaner. Currently it relies on utility functions get_params and get_sample_stats, which would need to be added to the DynamicPPL API so that Turing could overload them for its samplers.

sethaxen · Feb 17 '23 12:02
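
For illustration, here is a hedged sketch of such a Model-restricted method; the get_params/get_sample_stats calls and the from_namedtuple conversion are assumptions standing in for the actual proof-of-concept code:

using AbstractMCMC, DynamicPPL, InferenceObjects

function AbstractMCMC.bundle_samples(
    samples::Vector,
    model::DynamicPPL.Model,  # DynamicPPL-owned type, so defining this in DynamicPPL is not piracy
    sampler::AbstractMCMC.AbstractSampler,
    state,
    ::Type{InferenceObjects.InferenceData};
    kwargs...,
)
    # get_params/get_sample_stats are the utilities mentioned above;
    # their exact signatures here are assumed for illustration.
    params = get_params(model, samples)
    stats = get_sample_stats(samples)
    return InferenceObjects.from_namedtuple(params; sample_stats=stats)
end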

It's a weak dependency on Julia >= 1.9 if you declare it both as a strong dependency and a weak dependency. See https://pkgdocs.julialang.org/dev/creating-packages/#Transition-from-normal-dependency-to-extension.

devmotion · Feb 17 '23 13:02
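
Concretely, that transition pattern amounts to listing the package in both [deps] and [weakdeps] in DynamicPPL's Project.toml (UUIDs elided here):

[deps]
InferenceObjects = "<uuid>"

[weakdeps]
InferenceObjects = "<uuid>"

[extensions]
DynamicPPLInferenceObjectsExt = "InferenceObjects"

On Julia >= 1.9 the entry is then treated as a weak dependency that only loads together with the extension; on older versions it behaves as an ordinary strong dependency.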

For Chains the corresponding methods live in Turing proper. That might be cleaner.

There have been long discussions (and even issues IIRC, maybe even in DynamicPPL?) about how messy the current situation is - e.g., in many places in DynamicPPL we rely on functionality that is only implemented in MCMCChains but avoid having it as a dependency and instead allow AbstractChains. Similarly, I think not all code in Turing should actually be there.

devmotion · Feb 17 '23 13:02

For Chains the corresponding methods live in Turing proper. That might be cleaner.

There have been long discussions (and even issues IIRC, maybe even in DynamicPPL?) about how messy the current situation is - e.g., in many places in DynamicPPL we rely on functionality that is only implemented in MCMCChains but avoid having it as a dependency and instead allow AbstractChains. Similarly, I think not all code in Turing should actually be there.

How would you recommend proceeding then for this integration?

sethaxen · Feb 17 '23 14:02

Is there a particular reason why we don't just add it to Turing for now? I agree a weak dep might make sense, but it's a bit annoying to make it an explicit dependency pre-1.9, no? In Turing.jl I'm guessing the increased compilation time will be minor in comparison.

torfjelde · Feb 17 '23 17:02

Is there a particular reason why we don't just add it to Turing for now?

If the code lived in Turing it would entirely be type piracy. Other than that, I don't see a good reason.

sethaxen · Feb 17 '23 17:02

A 20% increase in load time is probably not a big deal, I think.

yebai · Feb 17 '23 17:02

How would you recommend proceeding then for this integration?

I guess short-term basically anything goes - it can just take a long time to move away from and improve a suboptimal but somewhat working setup, in my experience.

In the longer term, I think a better solution would be

  • to implement chainstack etc. in the InferenceObjects package itself by depending on AbstractMCMC or, see below, an even more lightweight chains package (similar to sampler packages)
  • to generalize the methods in DynamicPPL such as pointwise_loglikelihoods, loglikelihood, etc. in such a way that they can work with arbitrary AbstractChains as input, similar to how arrays with dimensions are supported by MCMCDiagnosticTools
  • to make Turing also AbstractChains/chain_type agnostic

The last two points probably require some additions to the AbstractChains interface (well, there isn't one yet), in AbstractMCMC or some other, even more lightweight, package. For instance, I have thought for a while that something like eachsample, eachchain, etc. (similar to eachslice) could be useful and be used instead of the explicit 1:size(chain, 1) etc. in the current code in DynamicPPL.

devmotion · Feb 17 '23 18:02
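
As a rough illustration of the eachsample/eachchain idea (the names and the draws × params × chains indexing convention are assumptions, not an existing interface):

# Hypothetical generic fallbacks over a draws × params × chains array:
eachchain(chain) = (chain[:, :, k] for k in 1:size(chain, 3))
eachsample(chain) = (chain[i, :, k] for i in 1:size(chain, 1), k in 1:size(chain, 3))

Chain types with richer structure (like the Dataset example below) could then overload these instead of relying on explicit index arithmetic.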

If the code lived in Turing it would entirely be type piracy. Other than that, I don't see a good reason.

That is basically the entirety of Turing.jl though haha :sweat_smile:

I guess short-term basically anything goes - it can just take a long time to move away from and improve a suboptimal but somewhat working setup, in my experience.

But this is a fair point :confused: We have quite a lot of examples of that.

torfjelde · Feb 17 '23 21:02

I guess short-term basically anything goes - it can just take a long time to move away from and improve a suboptimal but somewhat working setup, in my experience.

Okay, I think then the approach I will take in the short term is:

  1. Adapt all the code from DynamicPPLInferenceObjects, except the code in bundle_samples.jl, into an extension module DynamicPPLInferenceObjectsExt (a skeleton is sketched below).
  2. Add the code in bundle_samples.jl to Turing.
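
A minimal sketch of that extension skeleton, with the module contents elided:

module DynamicPPLInferenceObjectsExt

using DynamicPPL: DynamicPPL
using InferenceObjects: InferenceObjects

# Code adapted from DynamicPPLInferenceObjects.jl goes here, except the
# bundle_samples.jl methods, which move to Turing per item 2.

end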

In the longer term, I think a better solution would be

These all sound good, but at the moment I lack the bandwidth to tackle them.

For instance, I have thought for a while that something like eachsample, eachchain, etc. (similar to eachslice) could be useful and be used instead of the explicit 1:size(chain, 1) etc. in the current code in DynamicPPL.

Agreed! Actually, this would be automatically supported in InferenceObjects once DimensionalData adds eachslice support for AbstractDimStack (see https://github.com/rafaqz/DimensionalData.jl/pull/418), but in the meantime this works:

julia> using InferenceObjects, DimensionalData

julia> function _eachslice(ds::InferenceObjects.Dataset; dims)
           concrete_dims = DimensionalData.dims(ds, dims)
           return (view(ds, d...) for d in DimensionalData.DimIndices(concrete_dims))
       end;

julia> eachchain(ds::InferenceObjects.Dataset) = _eachslice(ds; dims=DimensionalData.Dim{:chain});

julia> function eachsample(ds::InferenceObjects.Dataset)
           sample_dims = (DimensionalData.Dim{:chain}, DimensionalData.Dim{:draw})
           return _eachslice(ds; dims=sample_dims)
       end;

julia> using ArviZExampleData;

julia> ds = load_example_data("centered_eight").posterior
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> collect(eachchain(ds))[1]
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw} (500)
  :theta Float64 dims: Dim{:school}, Dim{:draw} (8×500)
  :tau   Float64 dims: Dim{:draw} (500)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> collect(eachsample(ds))[1]
Dataset with dimensions: 
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: 
  :theta Float64 dims: Dim{:school} (8)
  :tau   Float64 dims: 

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

Edit: adapted for https://github.com/rafaqz/DimensionalData.jl/pull/462

sethaxen · Feb 17 '23 21:02