Add experimental `GibbsKernel`
GibbsKernel is a tfp.mcmc.TransitionKernel that allows a sequence
of transition kernels to be bundled together to sequentially update
different parts of a full state. Currently the semantics are to map
one state part per Gibbs step, though this is definitely up for
discussion.
GibbsKernel currently refers to each state part by index. This
leads to poor readability of the code, particularly in large Gibbs
schemes where the target_log_prob_fn definition scrolls off the top
of the page. The time seems right to PR this kernel, as discussions
around turnkey MCMC solutions are progressing!
Thanks, Chris! I started taking a look at this, but feel free to ping in a few days if there are no comments.
@ColCarroll no worries! I feel a Gibbs sampling capability is something badly needed in TFP currently. Although the HMC (and derivatives) literature focuses on single-kernel chains, the wider class of models involving discontinuous log_prob functions is important -- not least in current event-time analysis in the COVID epidemic.
I should also provide a description of the logic of GibbsKernel:
- `GibbsKernel` is essentially a meta-kernel which sequentially calls instances of `tfp.mcmc.TransitionKernel` on different state parts.
- `GibbsKernelResults` tracks the value of `target_log_prob` (on the target scale!), whilst keeping the individual steps' "results" structures stored in a list.
- `GibbsKernel.one_step` sequentially forwards `target_log_prob` to each `previous_results` structure before passing it to the respective Gibbs step's `one_step` function.
- Instances of `tfp.mcmc.TransformedTransitionKernel` are detected, and the bijector is invoked in order to map `target_log_prob` --> `transformed_target_log_prob`.
- The mixing of non-gradient based kernels (e.g. RWM) with gradient-based kernels poses an issue with forwarding gradient information. Currently, we take the pragmatic approach of calling `k.bootstrap_results` to reset `target_log_prob` and associated gradients if `k`'s `previous_results` contains gradient information. Doubtless we could be cleverer!
- A design issue with this kernel is the presence of a Python `list` within the `GibbsKernelResults` named_tuple. This breaks the recursive functions such as `get_innermost` and `get_outermost`. Ideas welcome.
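
The control flow described above can be sketched in plain Python. This is a hypothetical illustration, not the code in this PR: `GibbsResultsSketch`, `make_rwm_step` and `gibbs_one_step` are made-up names, the per-part updates are simple random-walk Metropolis steps, and no TFP APIs, bijectors or gradients are involved. It only shows the running `target_log_prob` being forwarded from one step to the next while each step's own results are collected in a list.

```python
import collections
import math
import random

# Hypothetical stand-in for GibbsKernelResults: the log prob of the full
# target plus a plain Python list holding each step's own results.
GibbsResultsSketch = collections.namedtuple(
    "GibbsResultsSketch", ["target_log_prob", "inner_results"])


def make_rwm_step(index, scale):
    """Returns a random-walk Metropolis update for state part `index`."""
    def step(state, current_log_prob, target_log_prob_fn, rng):
        proposal = list(state)
        proposal[index] = state[index] + rng.gauss(0.0, scale)
        proposed_log_prob = target_log_prob_fn(proposal)
        # Metropolis acceptance probability for a symmetric proposal.
        accept_prob = math.exp(min(0.0, proposed_log_prob - current_log_prob))
        if rng.random() < accept_prob:
            return proposal, proposed_log_prob, {"is_accepted": True}
        return list(state), current_log_prob, {"is_accepted": False}
    return step


def gibbs_one_step(state, previous_results, steps, target_log_prob_fn, rng):
    """Applies each per-part step in turn, forwarding the running log prob."""
    log_prob = previous_results.target_log_prob
    inner_results = []
    for step in steps:
        state, log_prob, step_results = step(
            state, log_prob, target_log_prob_fn, rng)
        inner_results.append(step_results)
    return state, GibbsResultsSketch(log_prob, inner_results)


def target_log_prob_fn(state):
    """Unnormalised log density of two independent standard normals."""
    return -0.5 * (state[0] ** 2 + state[1] ** 2)


rng = random.Random(0)
steps = [make_rwm_step(0, 1.0), make_rwm_step(1, 1.0)]
state = [3.0, -3.0]
results = GibbsResultsSketch(target_log_prob_fn(state), [])
for _ in range(1000):
    state, results = gibbs_one_step(
        state, results, steps, target_log_prob_fn, rng)
```

Because every step receives the log prob left behind by the previous step, the target is evaluated only once per proposal. The list in `inner_results` is the same shape of structure that the last bullet flags as awkward for recursive helpers.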
@ColCarroll ping
A few thoughts after looking over this:
1. If you had a complete conditional (which we should for the MVN case?), how would you implement the `make_kernel_fn`? Could `make_kernel_fn` return a `tfd.Distribution`, and if we get back a distribution we literally call `dist.sample(seed=seed)`? Or could we add a `SamplingKernel` that takes a `Distribution` and `one_step` just returns another sample? Can you add this as an example?
2. Could you add an example that incorporates a discrete latent, and showcases using a different kernel for a different state part (MH for discrete, HMC for continuous)?
3. It is a bit limiting that we only deal with single state parts at a time, though I guess that could be generalized later. Also a bit limiting to address them by index (though I guess that's more a limitation of the overall TFP MCMC stack supporting only lists and tuples).
4. The gradient resetting stuff is icky (but we do similar things in REMC).
5. I don't quite follow the necessity of unwrapping `TransformedTransitionKernel` instances. Can you explain why that's necessary? We have some cases of joint distributions where the descendent event space bijectors depend on the ancestral sample values. e.g:

   ```python
   @tfd.JointDistributionCoroutine
   def model():
     low = yield tfd.Normal(0, 1)
     span = yield tfd.LogNormal(0, 1)
     x = yield tfd.Uniform(low, low+span)
     y = yield tfd.Normal(x, 1)
   ```

   In this case the joint bijector constraining `x` to its support needs the values of `low` and `low+span` to shift and scale the sigmoid bijection. My main pt here is that sometimes TTK can act across state parts and not only within a single state part.
6. Could we add support for "make_kernel_for_the_rest_of_the_states" where only the explicitly specified indices get a make_kernel_fn, and everything else can use (e.g.) a single TTK(HMC(..))? I guess this makes sense to defer to the future as well, alongside (3) above.
7. Can you clarify: What about `list` and `unnest` causes issues?
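
To make the first suggestion (a step that draws directly from a complete conditional) concrete, here is a minimal pure-Python sketch. It is an assumption about what such a step could look like, not TFP code: `make_sampling_step` and the `conditional_sampler_fn(state, rng)` signature are hypothetical, and the target is an illustrative standard bivariate normal with correlation `RHO`, whose complete conditionals are known in closed form.

```python
import math
import random

RHO = 0.8  # Correlation of the illustrative bivariate normal target.


def make_sampling_step(index, conditional_sampler_fn):
    """Returns an update that redraws part `index` from its complete conditional.

    `conditional_sampler_fn(state, rng)` is a hypothetical callable that draws
    the new value given the other state parts; no accept/reject is needed.
    """
    def step(state, rng):
        new_state = list(state)
        new_state[index] = conditional_sampler_fn(state, rng)
        return new_state
    return step


def sample_x_given_y(state, rng):
    # x | y ~ Normal(RHO * y, sqrt(1 - RHO^2)) for a standard bivariate normal.
    return rng.gauss(RHO * state[1], math.sqrt(1.0 - RHO ** 2))


def sample_y_given_x(state, rng):
    # y | x ~ Normal(RHO * x, sqrt(1 - RHO^2)) by symmetry.
    return rng.gauss(RHO * state[0], math.sqrt(1.0 - RHO ** 2))


rng = random.Random(42)
steps = [make_sampling_step(0, sample_x_given_y),
         make_sampling_step(1, sample_y_given_x)]
state = [0.0, 0.0]
draws = []
for _ in range(20000):
    for step in steps:
        state = step(state, rng)
    draws.append(tuple(state))

# Empirical moments of the chain, to compare against the target.
n = len(draws)
mean_x = sum(x for x, _ in draws) / n
mean_y = sum(y for _, y in draws) / n
var_x = sum((x - mean_x) ** 2 for x, _ in draws) / n
var_y = sum((y - mean_y) ** 2 for _, y in draws) / n
cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in draws) / n
sample_corr = cov_xy / math.sqrt(var_x * var_y)
```

With exact conditional draws every update is accepted, so such a step would carry no acceptance statistics or gradients in its results. How it should report `target_log_prob` to a surrounding Gibbs kernel is left open here.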