blackjax
Add HMC Swindles
There are several algorithms in https://arxiv.org/abs/2001.05033 that should be simple to implement:
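- HMC-COUPLED is a kernel that runs two identical kernels with the same `rng_key`:

  ```python
  import jax

  def hmc_coupled(rng_key, states):
      # `kernel(rng_key, state)` is an HMC step; in_axes=None shares the key across both chains.
      states, infos = jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)
      return states, infos
  ```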
- HMC-ANTITHETIC is a kernel that runs two kernels with opposite momenta and an identical `rng_key`.
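Both swindles can be sketched with one HMC step that takes a momentum sign: coupling shares the key and the sign, while the antithetic version shares the key but negates the momentum of the second chain. This is a minimal, self-contained sketch in plain JAX, not BlackJAX's API. It assumes a standard-normal toy target, a fixed step size of 0.2 and 10 leapfrog steps; `logdensity`, `leapfrog`, `hmc_step` and `run` are hypothetical helpers written for illustration.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical toy target (assumption): standard normal.
    return -0.5 * jnp.sum(x**2)


def leapfrog(x, p, step_size, num_steps):
    # Standard leapfrog: half momentum step, alternating full steps, final half step.
    grad = jax.grad(logdensity)
    p = p + 0.5 * step_size * grad(x)

    def body(_, carry):
        x, p = carry
        x = x + step_size * p
        p = p + step_size * grad(x)
        return x, p

    x, p = jax.lax.fori_loop(0, num_steps - 1, body, (x, p))
    x = x + step_size * p
    p = p + 0.5 * step_size * grad(x)
    return x, p


def hmc_step(rng_key, x, momentum_sign=1.0, step_size=0.2, num_steps=10):
    # One hypothetical HMC transition; momentum_sign=-1 negates the sampled momentum,
    # which is still a valid N(0, I) draw, so each chain alone is ordinary HMC.
    key_p, key_u = jax.random.split(rng_key)
    p0 = momentum_sign * jax.random.normal(key_p, x.shape)
    x1, p1 = leapfrog(x, p0, step_size, num_steps)
    h0 = -logdensity(x) + 0.5 * jnp.sum(p0**2)
    h1 = -logdensity(x1) + 0.5 * jnp.sum(p1**2)
    accept = jnp.log(jax.random.uniform(key_u)) < h0 - h1
    return jnp.where(accept, x1, x)


def hmc_coupled(rng_key, states):
    # Same key for both chains: identical momenta and accept/reject uniforms.
    return jax.vmap(hmc_step, in_axes=(None, 0))(rng_key, states)


def hmc_antithetic(rng_key, states):
    # Same key, but the second chain uses the negated momentum.
    signs = jnp.array([1.0, -1.0])
    return jax.vmap(hmc_step, in_axes=(None, 0, 0))(rng_key, states, signs)


def run(kernel, rng_key, states, num_samples=100):
    # Apply `kernel` num_samples times and return the trace of both chains.
    def body(states, key):
        states = kernel(key, states)
        return states, states

    keys = jax.random.split(rng_key, num_samples)
    _, trace = jax.lax.scan(body, states, keys)
    return trace  # shape (num_samples, 2, dim)


init = jnp.array([[3.0, -2.0], [-1.0, 4.0]])
key = jax.random.PRNGKey(0)
trace_coupled = run(hmc_coupled, key, init)
trace_antithetic = run(hmc_antithetic, key, init)
```

On this symmetric toy target the coupled chains should collapse onto each other, while the antithetic chains should become mirror images, so averaging an odd function such as `f(x) = x` over the antithetic pair nearly cancels. That cancellation is the variance reduction the swindle relies on.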
Seemingly unrelated, but Metropolis-within-Gibbs has been shown to work well (https://github.com/blackjax-devs/blackjax/discussions/275), so this is definitely feasible.