
Add experimental `GibbsKernel`

Open chrism0dwk opened this issue 4 years ago • 4 comments

GibbsKernel is a tfp.mcmc.TransitionKernel that allows a sequence of transition kernels to be bundled together to sequentially update different parts of a full state. Currently the semantics are that each Gibbs step updates exactly one state part, though this is definitely up for discussion.

GibbsKernel currently refers to each state part by index. This leads to poor readability of the code, particularly in large Gibbs schemes where the target_log_prob_fn definition scrolls off the top of the page. The time seems right to PR this kernel, as discussions around turnkey MCMC solutions are progressing!
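
For concreteness, here is a rough sketch of the intended usage. GibbsKernel here is the kernel proposed in this issue (not something currently importable from tfp), and the constructor signature, the kernel_list name, and the (index, make_kernel_fn) pairing are illustrative rather than final:

    import tensorflow_probability as tfp
    tfd = tfp.distributions

    # Two state parts, referenced by position in target_log_prob_fn.
    def target_log_prob_fn(theta, x):
        return (tfd.Normal(0., 1.).log_prob(theta) +
                tfd.Normal(theta, 1.).log_prob(x))

    # Illustrative only: one inner kernel per state part, applied in sequence.
    gibbs = GibbsKernel(
        target_log_prob_fn=target_log_prob_fn,
        kernel_list=[
            (0, lambda tlp: tfp.mcmc.RandomWalkMetropolis(tlp)),
            (1, lambda tlp: tfp.mcmc.HamiltonianMonteCarlo(
                    tlp, step_size=0.1, num_leapfrog_steps=3)),
        ])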

chrism0dwk avatar Feb 07 '21 14:02 chrism0dwk

Thanks, Chris! I started taking a look at this, but feel free to ping in a few days if there are no comments.

ColCarroll avatar Feb 09 '21 16:02 ColCarroll

@ColCarroll no worries! I feel a Gibbs sampling capability is something badly needed in TFP currently. Although the HMC (and derivatives) literature focuses on single-kernel chains, the wider class of models involving discontinuous log_prob functions is important -- not least in current event-time analysis in the COVID epidemic.

I should also provide a description of the logic of GibbsKernel:

  1. GibbsKernel is essentially a meta-kernel which sequentially calls instances of tfp.mcmc.TransitionKernel on different state parts (a standalone sketch of this sweep logic follows the list).
  2. GibbsKernelResults tracks the value of target_log_prob (on the target scale!), whilst keeping the individual steps' "results" structures stored in a list.
  3. GibbsKernel.one_step sequentially forwards target_log_prob to each previous_results structure before passing it to the respective Gibbs step's one_step function.
  4. Instances of tfp.mcmc.TransformedTransitionKernel are detected, and the bijector is invoked in order to map target_log_prob --> transformed_target_log_prob.
  5. The mixing of non-gradient based kernels (e.g. RWM) with gradient-based kernels poses an issue with forwarding gradient information. Currently, we take the pragmatic approach of calling k.bootstrap_results to reset target_log_prob and associated gradients if k's previous_results contains gradient information. Doubtless we could be cleverer!
  6. A design issue with this kernel is the presence of a Python list within the GibbsKernelResults namedtuple. This breaks recursive helpers such as get_innermost and get_outermost. Ideas welcome.
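
To make (1)-(3) concrete, here is a minimal standalone sketch of one Gibbs sweep written directly against the public tfp.mcmc API. This is not the GibbsKernel implementation itself: for brevity it re-bootstraps each inner kernel's results rather than forwarding target_log_prob, and it handles neither TransformedTransitionKernel detection (4) nor gradient resetting (5):

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    def joint_log_prob(theta, x):
        return (tfd.Normal(0., 1.).log_prob(theta) +
                tfd.Normal(theta, 1.).log_prob(x))

    def gibbs_sweep(state, make_kernel_fns, seed):
        # One sweep: update each state part in turn, conditioning on the
        # current values of the remaining parts.
        state = list(state)
        seeds = tfp.random.split_seed(seed, n=len(make_kernel_fns))
        for i, (make_kernel_fn, part_seed) in enumerate(zip(make_kernel_fns, seeds)):
            def conditional_log_prob(part, i=i):
                parts = list(state)
                parts[i] = part
                return joint_log_prob(*parts)
            kernel = make_kernel_fn(conditional_log_prob)
            # Simplification: re-bootstrap instead of forwarding target_log_prob
            # into the inner kernel's previous results.
            results = kernel.bootstrap_results(state[i])
            state[i], _ = kernel.one_step(state[i], results, seed=part_seed)
        return state

    make_kernel_fns = [
        lambda tlp: tfp.mcmc.RandomWalkMetropolis(tlp),
        lambda tlp: tfp.mcmc.HamiltonianMonteCarlo(
            tlp, step_size=0.1, num_leapfrog_steps=3),
    ]
    state = gibbs_sweep([tf.constant(0.), tf.constant(0.)],
                        make_kernel_fns, seed=[4, 2])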

chrism0dwk avatar Feb 10 '21 12:02 chrism0dwk

@ColCarroll ping

coatless avatar Feb 17 '21 01:02 coatless

A few thoughts after looking over this:

  1. If you had a complete conditional (which we should for the MVN case?), how would you implement the make_kernel_fn? Could make_kernel_fn return a tfd.Distribution, and if we get back a distribution we literally call dist.sample(seed=seed)? Or could we add a SamplingKernel that takes a Distribution and whose one_step just returns another sample? Can you add this as an example? (A rough sketch of what I have in mind follows this list.)
  2. Could you add an example that incorporates a discrete latent, and showcases using a different kernel for a different state part (MH for discrete, HMC for continuous)?
  3. It is a bit limiting that we only deal with single state parts at a time, though I guess that could be generalized later. Also a bit limiting to address them by index (though I guess that's more a limitation of the overall TFP MCMC stack supporting only lists and tuples).
  4. The gradient resetting stuff is icky (but we do similar things in REMC).
  5. I don't quite follow the necessity of unwrapping TransformedTransitionKernel instances. Can you explain why that's necessary? We have some cases of joint distributions where the descendant event space bijectors depend on the ancestral sample values, e.g.:
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    @tfd.JointDistributionCoroutine
    def model():
        low = yield tfd.Normal(0., 1.)
        span = yield tfd.LogNormal(0., 1.)
        x = yield tfd.Uniform(low, low + span)
        y = yield tfd.Normal(x, 1.)

    In this case the joint bijector constraining x to its support needs the values of low and low+span to shift and scale the sigmoid bijection. My main point here is that sometimes TTK can act across state parts and not only within a single state part.
  6. Could we add support for "make_kernel_for_the_rest_of_the_states" where only the explicitly specified indices get a make_kernel_fn, and everything else can use (e.g.) a single TTK(HMC(..))? I guess this makes sense to defer to the future as well, alongside (3) above.
  7. Can you clarify: what is it about the list and unnest that causes issues?
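
Roughly what I have in mind for (1), as a minimal sketch built on the tfp.mcmc.TransitionKernel interface; the class name, constructor, and empty results structure are all placeholders:

    import collections
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    # Placeholder: empty results, since exact conditional draws need no
    # acceptance diagnostics.
    SamplingKernelResults = collections.namedtuple('SamplingKernelResults', [])

    class SamplingKernel(tfp.mcmc.TransitionKernel):
        """Sketch of a kernel whose one_step draws directly from a complete conditional."""

        def __init__(self, distribution):
            # `distribution` is the complete conditional for this state part,
            # typically constructed with the other state parts held fixed.
            self._distribution = distribution

        @property
        def is_calibrated(self):
            # Exact draws from the conditional need no Metropolis correction.
            return True

        def one_step(self, current_state, previous_kernel_results, seed=None):
            return self._distribution.sample(seed=seed), previous_kernel_results

        def bootstrap_results(self, init_state):
            return SamplingKernelResults()

    # Usage: each "step" is just a fresh draw from the conditional.
    kernel = SamplingKernel(tfd.Normal(0., 1.))
    results = kernel.bootstrap_results(0.)
    next_state, results = kernel.one_step(0., results, seed=[0, 1])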

brianwa84 avatar Mar 02 '21 15:03 brianwa84