AbstractMCMC.jl

Add some interface functions to support the new Gibbs sampler in Turing

Open sunxd3 opened this issue 7 months ago • 20 comments

The recently introduced Gibbs sampler provides a way forward for the Turing inference stack.

A near- to medium-term goal has been to further reduce the glue code between Turing and inference packages (ref https://github.com/TuringLang/Turing.jl/issues/2281). The new Gibbs implementation lays out a clear plan for achieving this goal.

This PR is modeled after the interface of @torfjelde's recent PR and, in some respects, is a rehash of https://github.com/TuringLang/AbstractMCMC.jl/pull/86.

(The explanation below is outdated; please refer to https://github.com/TuringLang/AbstractMCMC.jl/pull/144#issuecomment-2337681868.)

~~The goal of this PR is to determine and implement the necessary interface improvements so that, once the inference packages are updated to conform to the interface, they will more or less "just work" with the new Gibbs implementation.~~

~~As a first step, we trial two new functions, recompute_logprob!!(rng, model, sampler, state) and getparams(state):~~

  • ~~recompute_logprob!!(rng, model, sampler, state) recomputes the logprob given the state~~
  • ~~getparams(state) extracts the parameter values~~
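
To make the two function signatures from this (now outdated) draft concrete, here is a minimal, self-contained sketch. ToyModel, ToySampler, and ToyState are hypothetical stand-ins invented for illustration; they are not part of AbstractMCMC.jl, and the method bodies are only one possible reading of the proposal, not the implementation in this PR.

```julia
using Random

# Hypothetical toy types for illustration only; not part of AbstractMCMC.jl.
struct ToyModel{F}
    logdensity::F            # maps a parameter vector to a log density
end

struct ToySampler end

struct ToyState
    params::Vector{Float64}  # parameter values from the current iteration
    logp::Float64            # cached log density evaluated at `params`
end

# Sketch of the proposed accessor: extract the parameter values from a state.
getparams(state::ToyState) = state.params

# Sketch of the proposed recomputation: return a state whose cached log density
# is consistent with `model`. ToyState is immutable, so a new state is returned;
# `rng` is accepted only to match the proposed signature and is unused here.
function recompute_logprob!!(
    rng::Random.AbstractRNG, model::ToyModel, sampler::ToySampler, state::ToyState
)
    return ToyState(state.params, model.logdensity(state.params))
end

# Usage: the state was built under `model_a`; after the model changes to
# `model_b` (for example, because another Gibbs block updated the variables it
# is conditioned on), the cached log density is stale and gets recomputed.
model_a = ToyModel(x -> -0.5 * sum(abs2, x))
model_b = ToyModel(x -> -0.5 * sum(abs2, x .- 1.0))
state = ToyState([1.0, 2.0], model_a.logdensity([1.0, 2.0]))
new_state = recompute_logprob!!(Random.default_rng(), model_b, ToySampler(), state)
```

In this sketch the parameters are left untouched and only the cached log density changes, which is the property a Gibbs step would rely on when the conditioned model changes between sub-sampler updates.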

~~Some considerations:~~

  • ~~This assumes that AbstractMCMC-compatible inference packages implement a state, and that a state stores at least the parameter values from the current iteration (traditionally in the form of a Transition) and the logprob.~~
  • ~~recompute_logprob!!(rng, model, sampler, state)~~
    • ~~do we need rng?~~
    • ~~should we make model an AbstractMCMC.LogDensityModel or just a LogDensityProblem (and make inference packages depend on LogDensityProblems in the latter case)? This should allow inference packages to be independent of DynamicPPL; we can use getparams to construct a varinfo in Turing~~
  • ~~getparams(state)~~
    • ~~What does this function return? A vector, a transition?~~
    • ~~Do we need setparams?~~
  • ~~Do we also need some interface functions for state like getstats?~~

~~Tor also says (in a Slack conversation) that a condition(model, params) function is needed, but that it is better implemented by the packages that define the model, which I agree with.~~

sunxd3 avatar Jul 12 '24 15:07 sunxd3