Turing.jl

Better support for AbstractSampler

Open • torfjelde opened this issue 1 year ago • 26 comments

Now that we have DynamicPPL.LogDensityFunction + AbstractMCMC.LogDensityModel, it would be nice to drop the usage of Turing.Inference.InferenceAlgorithm and allow "arbitrary" AbstractMCMC.AbstractSamplers to be used within Turing without much effort on the user's side.

This would of course also help alleviate the discrepancy in configuration between the package that implements a sampler and the corresponding "inference algorithm" in Turing.jl.

The problem with simply replacing InferenceAlgorithm with AbstractSampler and AbstractMCMC.AbstractModel with DynamicPPL.Model, as is done in #2008, is method ambiguities. I also assume that method ambiguities were part of the reason why InferenceAlgorithm was kept around.
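A minimal, self-contained sketch of the kind of ambiguity meant here, using hypothetical stand-in types rather than the real AbstractMCMC/DynamicPPL ones: a sampler package specializes on its own sampler and leaves the model generic, while Turing specializes on its model and leaves the sampler generic. A call with both concrete types then matches two methods, and neither is more specific than the other.

```julia
module AmbiguityDemo
# Hypothetical stand-ins, not the real AbstractMCMC / DynamicPPL types.
abstract type AbstractModel end
abstract type AbstractSampler end
struct Model <: AbstractModel end        # plays the role of DynamicPPL.Model
struct MySampler <: AbstractSampler end  # a sampler from some external package

# Sampler package: generic in the model, specific in the sampler.
sample(::AbstractModel, ::MySampler) = :from_sampler_package
# Turing, with InferenceAlgorithm replaced by AbstractSampler:
# specific in the model, generic in the sampler.
sample(::Model, ::AbstractSampler) = :from_turing
end

# Julia throws a MethodError ("... is ambiguous") for this call.
is_ambiguous = try
    AmbiguityDemo.sample(AmbiguityDemo.Model(), AmbiguityDemo.MySampler())
    false
catch err
    err isa MethodError
end
```

Julia only resolves such a call if a third method covering the exact pair is defined; dispatching on a Turing-owned type such as InferenceAlgorithm sidesteps the overlap instead.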

There are, as I see it, two solutions:

  1. Replace the current overloads of AbstractMCMC.sample in Turing.jl with Turing's own sample, i.e. make Turing.sample !== AbstractMCMC.sample. Turing.sample then calls AbstractMCMC.sample under the hood but with some additional niceties.
  2. Keep something equivalent to InferenceAlgorithm or a wrapper type and continue overloading AbstractMCMC.sample.
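A minimal sketch of option (1), using hypothetical stub modules (`StubMCMC`, `StubSampler`, `StubTuring`) in place of AbstractMCMC, a sampler package, and Turing. The sampler package implements the generic method once for a log-density model, and Turing's separately owned `sample` converts its model, delegates, and attaches its own extras. Because `StubTuring.sample` is a different function, its methods cannot be ambiguous with the sampler package's methods.

```julia
module StubMCMC  # hypothetical stand-in for AbstractMCMC
abstract type AbstractSampler end
struct LogDensityModel{F}  # stand-in for AbstractMCMC.LogDensityModel
    logdensity::F
end
function sample end  # the shared generic function sampler packages extend
end

module StubSampler  # hypothetical sampler package
import ..StubMCMC
struct MySampler <: StubMCMC.AbstractSampler end
# Implemented once, for generic log-density models; returns dummy "draws".
StubMCMC.sample(m::StubMCMC.LogDensityModel, ::MySampler, N::Integer) =
    [m.logdensity(0.0) for _ in 1:N]
end

module StubTuring  # hypothetical stand-in for Turing
import ..StubMCMC
struct Model  # stand-in for DynamicPPL.Model
    logjoint::Function
end
# Turing's own function, NOT a method of StubMCMC.sample.
function sample(model::Model, sampler::StubMCMC.AbstractSampler, N::Integer)
    ldm = StubMCMC.LogDensityModel(model.logjoint)  # convert the model
    draws = StubMCMC.sample(ldm, sampler, N)        # delegate to the package
    return (; draws, names = [:x])                  # add Turing-side extras
end
end
```

The flip side, raised further down in the thread, is that users loading both packages see two different `sample` bindings.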

Personally, I'm very much in favour of option (1). This has several benefits:

  1. Removes the type-piracy.
  2. Maintains the status quo.

torfjelde • Jun 15 '23 12:06

@devmotion @yebai @JaimeRZP thoughts?

torfjelde • Jun 15 '23 12:06

Hi Tor!

Thanks for opening an issue. After some thought, I am also in favour of (1), i.e. Turing.sample. Besides the points you have raised, I can also see how it would make debugging easier in case a user manages to dispatch to the wrong function. Also, for some samplers Turing dispatches to sampler-specific methods of AbstractMCMC.sample. Having all of these specified within Turing seems the most elegant solution.

What do others think?

Best, Jaime

JaimeRZP • Jun 15 '23 13:06

I have been playing around with these ideas, and I have noticed that even if we get rid of InferenceAlgorithm, we would still need to wrap AbstractSampler in DynamicPPL.Sampler() if we do not wish to change many things in one single PR. Would something like the code below also lead to method ambiguities?

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::Model,
    sampler::Sampler{<:AbstractSampler},
    N::Integer;
    chain_type=MCMCChains.Chains,
    resume_from=nothing,
    progress=PROGRESS[],
    kwargs...
)
    # ...
end

JaimeRZP • Jun 16 '23 13:06

Quick note: if we do this, we will have the same method ambiguity in bundle_samples that we would have in sample.

JaimeRZP • Jun 16 '23 16:06

One consequence of (1) is that many users will run into warnings about conflicting exports, and many scripts and tutorials can't be used with newer Turing versions: AbstractMCMC.sample itself is not owned by AbstractMCMC but is just StatsBase.sample. (1) might still be preferable, but that's a somewhat unfortunate consequence of it, I think.

I wonder if there is a better way to fix these method ambiguities. What's the main problem to fix here? I think the problem is mainly that https://github.com/TuringLang/Turing.jl/pull/2008 does keep sample definitions with generic sampler and DynamicPPL.Model - even though I think they should be removed. If a sampler package defines sample, Turing should be able to reuse it automatically without any custom definitions if this implementation is defined for generic LogDensityModels.

devmotion • Jun 16 '23 20:06

What's the main problem to fix here?

Problem is: how do we avoid method ambiguities while at the same time allowing Turing.sample (whatever this method is) to add Turing's bells and whistles, e.g.

  • Good support for conversion into a Chains, i.e. if chain_type=Chains (which it should also be by default in Turing.sample) then we should extract parameter names, etc.
  • Convert DynamicPPL.Model into a LogDensityFunction or whatever the actual sampler wants.

I think the problem is mainly that #2008 does keep sample definitions with generic sampler and DynamicPPL.Model - even though I think they should be removed.

So previously that was InferenceAlgorithm + AbstractModel. I definitely agree that we should remove these overloads or something, but we want to do that while preserving the current behavior of Turing.sample, right?

If a sampler package defines sample, Turing should be able to reuse it automatically without any custom definitions if this implementation is defined for generic LogDensityModels.

Which definition are you referring to?

torfjelde • Jun 16 '23 20:06

One consequence of (1) is that many users will run into warnings about conflicting exports, and many scripts and tutorials can't be used with newer Turing versions: AbstractMCMC.sample itself is not owned by AbstractMCMC but is just StatsBase.sample.

Very much agree with this though :confused:

torfjelde • Jun 16 '23 20:06

while at the same time allowing Turing.sample (whatever this method is) to add Turing's bells and whistles, e.g.

My premise was that these would have to be removed or integrated in AbstractMCMC (possibly not the exact default values but some way to specify them).

Convert DynamicPPL.Model into a LogDensityFunction or whatever the actual sampler wants.

If the sampler expects a LogDensityModel, shouldn't this work automatically?

devmotion • Jun 16 '23 21:06

My premise was that these would have to be removed or integrated in AbstractMCMC (possibly not the exact default values but some way to specify them).

But how do we do either of these while still preserving the desired behavior mentioned above? :confused:

Complete removal of Turing.sample implementations seems ideal, but I don't understand how we can do that and preserve the status quo from a Turing user's perspective :confused:

If the sampler expects a LogDensityModel, shouldn't this work automatically?

I don't think I fully understand :confused: Can you write the signature you want Turing.sample to have (and whether it's AbstractMCMC or not)?

torfjelde • Jun 16 '23 22:06

I assumed that Model <: AbstractMCMC.LogDensityModel. Isn't that what we want? Then sample(rng, ::LogDensityModel, ::MyCustomSampler) would be sufficient (modulo Turing-specific defaults). But it seems currently it's just a subtype of AbstractMCMC.AbstractModel (https://github.com/TuringLang/AbstractPPL.jl/blob/0f289206b22da5feee03218c157afb28706ed8ec/src/abstractprobprog.jl#L11).

devmotion • Jun 16 '23 22:06

I assumed that Model <: AbstractMCMC.LogDensityModel. Isn't that what we want?

Uhmm, that is not at all what I had in mind :sweat_smile: But I don't quite see how this helps. Sure, it lets us remove the sample functions from Turing, but then we also lose everything that comes with them.

Then sample(rng, ::LogDensityModel, ::MyCustomSampler) would be sufficient (modulo Turing-specific defaults).

But the Turing-specific stuff is sort of the important bit, no? :sweat_smile: Otherwise, yes, things become much simpler.

torfjelde • Jun 16 '23 22:06

My feeling was that they can easily be kept by changing the upstream code a bit. After all, it seems the only Turing-specific keyword arguments are chain_type = Chains and resume_from? Support for the latter could be added in AbstractMCMC I think, and the chain_type default could be made a function of the model (and sampler, I assume? - but then we need clear instructions for who is allowed to implement it to avoid method ambiguities, I guess).

devmotion • Jun 17 '23 06:06

But how do we extract the variable names in this scenario? This also needs to enter somewhere in the process, right?

chain_type default could be made a function of the model (and sampler, I assume? - but then we need clear instructions for who is allowed to implement it to avoid method ambiguities, I guess).

But would that help us in this case? We would dispatch on ::Model and ::AbstractSampler then, which would still leave room for ambiguity?

torfjelde • Jun 17 '23 11:06

The variable names? Isn't that done by save!! and bundle_samples?

We can also prescribe the priority of arguments by defining eg

default_chain_type(model, sampler) = default_chain_type(sampler)
default_chain_type(sampler) = Any

and demanding that sampler packages only implement default_chain_type(::MySampler).
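A runnable sketch of this priority scheme in plain Julia; `MySampler`, `OtherSampler`, `ChainsStub` and `sample_stub` are hypothetical names. Only the two-argument fallback sees the model, while sampler packages add one-argument methods, so package methods and anything that dispatches on the model live on different arities and cannot be ambiguous with each other.

```julia
abstract type AbstractSampler end
struct MySampler <: AbstractSampler end     # hypothetical package sampler
struct OtherSampler <: AbstractSampler end  # a sampler that sets no default
struct ChainsStub end                       # hypothetical stand-in for a chains type

# Fallbacks, as proposed: the model-aware version defers to the sampler-only one.
default_chain_type(model, sampler) = default_chain_type(sampler)
default_chain_type(sampler) = Any

# What a sampler package would be allowed to add: a one-argument method only.
default_chain_type(::MySampler) = ChainsStub

# A hypothetical `sample` using it as the keyword default; users can still override.
sample_stub(model, sampler; chain_type = default_chain_type(model, sampler)) = chain_type
```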

But, honestly, maybe it would be best to remove chain_type completely. I've never liked it very much; it doesn't seem like a good approach for achieving composability. Maybe instead we should just demand that sample returns something that satisfies a Chains interface (similar to Tables.jl, with some defaults for base types like Vector{<:NamedTuple}). And then Chains packages could define their constructors/converters for anything that satisfies this interface.

devmotion • Jun 17 '23 12:06

The variable names? Isn't that done by save!! and bundle_samples?

Aye, but on what do we dispatch bundle_samples if we remove InferenceAlgorithm?

But, honestly, maybe it would be best to remove chain_type completely. I've never liked it very much; it doesn't seem like a good approach for achieving composability. Maybe instead we should just demand that sample returns something that satisfies a Chains interface (similar to Tables.jl, with some defaults for base types like Vector{<:NamedTuple}). And then Chains packages could define their constructors/converters for anything that satisfies this interface.

Very happy to go down this route! Though it will require some more effort on our part. I do agree with the comment on chain_type too.

EDIT: The effort comment isn't meant to discourage the idea, just to compare the effort with the above proposals. But as we've said, the above proposals are much worse alternatives.

torfjelde • Jun 17 '23 12:06

Aye, but on what do we dispatch bundle_samples if we remove InferenceAlgorithm?

On DynamicPPL.Model? As suggested in the previous comment, we could work around ambiguity issues by not dispatching on the model in the default fallback and when implementing it in the sampler packages but only in e.g. Turing.

devmotion • Jun 17 '23 12:06

As suggested in the previous comment, we could work around ambiguity issues by not dispatching on the model in the default fallback and when implementing it in the sampler packages but only in e.g. Turing

Ah, I see. Hmm, I don't know man, that seems somewhat fickle :confused: All it takes is one sampler package to add a ::AbstractModel in the bundle_samples def, no?

torfjelde • Jun 17 '23 13:06

If the official API/docs state that you must not do this, it's a bug in the sampler package :shrug:

devmotion • Jun 17 '23 13:06

Of course, it's just that it's something that's difficult to test properly, since it involves combinations of packages that are supposed to be independent of each other.

But let's take an example. In AdvancedHMC.jl we have the following implementation in the MCMCChains extension:

function AbstractMCMC.bundle_samples(
    ts::Vector{<:Transition},
    model::AbstractMCMC.AbstractModel,
    sampler::AbstractMCMC.AbstractSampler,
    state,
    chain_type::Type{Chains};
    kwargs...,
)

Here we would just remove the specialization on model::AbstractMCMC.AbstractModel, right? Making it

function AbstractMCMC.bundle_samples(
    ts::Vector{<:Transition},
    model,
    sampler::AbstractMCMC.AbstractSampler,
    state,
    chain_type::Type{Chains};
    kwargs...,
)

Then in Turing.jl, we'd implement what exactly? Clearly just specializing on the Model won't be sufficient here, right?
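A toy check of that suspicion, with hypothetical stand-in types only: even with the model argument left untyped in the sampler package, a Turing method that specializes only on `Model` (leaving the sampler generic) is still ambiguous with it, because each method is more specific in a different argument.

```julia
module BundleDemo
# Hypothetical stand-ins, not the real AdvancedHMC / MCMCChains / DynamicPPL types.
abstract type AbstractSampler end
struct Transition end
struct Chains end                         # stand-in for MCMCChains.Chains
struct Model end                          # stand-in for DynamicPPL.Model
struct HMCSampler <: AbstractSampler end

# Sampler package, with the model left untyped as proposed above.
bundle_samples(ts::Vector{<:Transition}, model, sampler::AbstractSampler, state,
               ::Type{Chains}) = :sampler_package
# Hypothetical Turing method specializing only on the model.
bundle_samples(ts::Vector, model::Model, sampler, state,
               ::Type{Chains}) = :turing
end

bundle_is_ambiguous = try
    BundleDemo.bundle_samples(BundleDemo.Transition[], BundleDemo.Model(),
                              BundleDemo.HMCSampler(), nothing, BundleDemo.Chains)
    false
catch err
    err isa MethodError  # Julia reports the call as ambiguous
end
```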

torfjelde • Jun 17 '23 14:06

No, I meant a different approach - the one outlined for default_chain_type above.

devmotion • Jun 17 '23 15:06

and demanding that sampler packages only implement default_chain_type(::MySampler).

Aaaah okay I forgot this. But then I don't understand how you can hook into, say, MCMCChains as a sampler :confused:

EDIT: Are you also saying that a sampler package shouldn't hook into MCMCChains?

torfjelde • Jun 17 '23 15:06

You can define default_chain_type(::MySampler) = Chains, can't you? But since it's just a default value anyway and users are free to specify any other value, my feeling is that it does not scale if sampler packages specialize too much on a single chains format - which would again be an argument for decoupling the chains output completely from the sampling.

devmotion • Jun 17 '23 18:06

You can define default_chain_type(::MySampler) = Chains, can't you?

But that's not enough. For example, in the example above, bundle_samples does things like transforming the variables back to the constrained space, extracting statistics, and putting everything in a format that is compatible with Chains. For this you then need to overload bundle_samples, right?

torfjelde • Jun 17 '23 18:06

transforming the variables back to the constrained space

Maybe a side remark but should that really be done in bundle_samples? Such transformations mean also that the iterator and transducer interface are inconsistent with sample.

I think if the samples are a collection of draws, then bundle_samples should merely add additional metadata and convert it to the desired output type. And the latter I think we should remove; and maybe we could just return (samples, metadata) or (; samples, metadata) by default.

If the chains packages implement constructors based on these return types, users could use

sample(model, sampler, N; ...) |> Chains

or

Chains(sample(model, sampler, N; ...))
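A minimal sketch of this decoupling, assuming `sample` returns a plain `(; samples, metadata)` named tuple whose `samples` is a `Vector{<:NamedTuple}`. `sample_stub` and `ChainsStub` are hypothetical stand-ins, not the MCMCChains API. A chains package would then only need one constructor for that generic shape, which is enough for both the piped and the call form.

```julia
# Hypothetical chains type; a real chains package would define its own.
struct ChainsStub
    names::Vector{Symbol}
    values::Matrix{Float64}   # rows = draws, columns = parameters
    info::NamedTuple
end

# Hypothetical `sample` result: plain draws plus metadata, no chain_type involved.
function sample_stub(N::Integer)
    samples = [(; x = float(i), y = 2.0 * i) for i in 1:N]  # Vector{<:NamedTuple}
    metadata = (; sampler = :stub, n_draws = N)
    return (; samples, metadata)
end

# The only thing the chains package needs: a constructor for the generic shape.
function ChainsStub(result::NamedTuple{(:samples, :metadata)})
    names = collect(keys(first(result.samples)))
    values = [Float64(s[n]) for s in result.samples, n in names]
    return ChainsStub(names, values, result.metadata)
end

chn = sample_stub(3) |> ChainsStub   # same as ChainsStub(sample_stub(3))
```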

devmotion • Jun 17 '23 19:06

Maybe a side remark but should that really be done in bundle_samples? Such transformations mean also that the iterator and transducer interface are inconsistent with sample.

I agree that it's not desirable, but it does have the advantage that it defers transforming the variables back to the constrained space just before you pass it to the user, i.e. AdvancedHMC doesn't have to worry about the transformations at all.
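To make the trade-off concrete, here is a toy sketch that is entirely hypothetical (it is not AdvancedHMC's code). The sampler only ever sees the unconstrained value theta = log(sigma), and the map back to sigma = exp(theta) happens at bundling time. The sampler stays transformation-agnostic, but draws obtained through an iterator-style interface (unconstrained) differ from what a bundling `sample` hands to the user (constrained).

```julia
# Hypothetical draw type: holds the unconstrained parameter theta = log(sigma).
struct Draw
    theta::Float64
end

# Dummy deterministic "transition", standing in for a real MCMC step.
transition_step(d::Draw) = Draw(d.theta + 0.1)

# Roughly what an iterator-style interface would expose: unconstrained draws.
function raw_draws(n::Integer)
    d = Draw(0.0)
    out = Draw[]
    for _ in 1:n
        d = transition_step(d)
        push!(out, d)
    end
    return out
end

# Roughly what a bundling step would return: constrained sigma = exp(theta) values.
bundle(ds::Vector{Draw}) = [exp(d.theta) for d in ds]
```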

torfjelde • Jun 17 '23 19:06

And the latter I think we should remove; and maybe we could just return (samples, metadata) or (; samples, metadata) by default.

Okay, then I see what you mean :) I also like the "manual" part of piping the result to Chains or something, but I don't like the idea of end users of Turing, which is supposed to be accessible to someone with minimal programming experience, having to go through this manual construction. Everywhere else, I very much agree that manual specification would be nice, as it would allow us to clean up this process significantly.

torfjelde • Jun 17 '23 19:06