numpyro
                                
HMCGibbs with chain_method="vectorized"
I am trying to use HMCGibbs sampling with more than one chain using chain_method="vectorized", but there appears to be a problem with how the random keys are split.
Consider this toy example that I copied from the numpyro documentation, where I only changed the chain_method and the number of chains:
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    y = hmc_sites['y']
    new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
    return {'x': new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(
    kernel,
    num_warmup=100,
    num_chains=2,
    num_samples=100,
    progress_bar=False,
    chain_method='vectorized',
)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()
I get the following error when running the above code:
TypeError: split accepts a single key, but was given a key array of shape (2,) != (). Use jax.vmap for batching.
Is there a way to make the vectorize option available for HMCGibbs sampling?
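For context, the error message above is what JAX reports when a batch of keys (one per chain) is passed to a function that expects a single key. The following is a minimal sketch in plain JAX, independent of numpyro, assuming the default legacy uint32 keys produced by random.PRNGKey. It only shows the vmap pattern the error message recommends, not how HMCGibbs is implemented.

```python
import jax
from jax import random

# One key per chain: with legacy uint32 keys this is an array of shape (2, 2).
keys = random.split(random.PRNGKey(0), 2)

# Passing `keys` directly to random.split is the kind of call that produces
# "split accepts a single key, but was given a key array ..." in recent JAX
# versions. Mapping the split over the chain axis avoids it:
per_chain = jax.vmap(random.split)(keys)

print(keys.shape)       # one key per chain
print(per_chain.shape)  # two fresh keys for each chain
```

Each chain then gets its own pair of subkeys along the leading axis, which is what a vectorized sampler needs at every step.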
Could you change this line to jax.vmap(...) with the default parallel method to see if it works for HMCGibbs?
Choosing the "parallel" option and changing to jax.vmap did not work for me. It seems like it still processes the chains in sequential order when I do that.
Did you set the host device count to the number of chains: https://num.pyro.ai/en/stable/utilities.html#set-host-device-count?
I think the HMCGibbs class's init method needs something similar to https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc.py#L782-L790 to detect if it is getting one key or a list of keys and vmap its initialization and sampling functions as needed.
When I was writing a custom Gibbs sampler (that does an HMC step for each conditional rather than drawing from a known distribution), I was able to get vectorized working this way, so it should be doable for this sampler as well.
I imagine it would look a bit like:
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs.copy()

    def init_fn(init_params, rng_key):
        ...
        return HMCGibbsState(z, hmc_state, rng_key)

    if is_prng_key(rng_key):
        # Single key: one chain, no batching needed.
        init_state = init_fn(init_params, rng_key)
        self._sample_fn = self._sample_one_chain
    else:
        # Batch of keys: map initialization over the chain dimension.
        init_state = vmap(init_fn)(init_params, rng_key)
        # Map over the state only; model_args and model_kwargs are shared.
        self._sample_fn = vmap(self._sample_one_chain, in_axes=(0, None, None))
    return device_put(init_state)
and rename the current sample method to _sample_one_chain, then add a new sample method that calls self._sample_fn.
It might need a bit of extra logic around it to work as expected, but I think this is roughly what the solution would look like.
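The single-key versus batch-of-keys dispatch described above can be tried out in isolation. Below is a minimal sketch in plain JAX that reuses the conditional from gibbs_fn in the original example. The names is_single_key, gibbs_step_one_chain and gibbs_step are hypothetical and not part of numpyro, and the shape check assumes legacy uint32 keys from random.PRNGKey.

```python
import jax
import jax.numpy as jnp
from jax import random


def is_single_key(rng_key):
    # Hypothetical helper: a legacy uint32 key has shape (2,),
    # a batch of n such keys has shape (n, 2).
    return rng_key.shape == (2,)


def gibbs_step_one_chain(rng_key, y):
    # Same conditional as gibbs_fn: x | y ~ Normal(0.8 * (1 - y), sqrt(0.8)).
    return 0.8 * (1.0 - y) + jnp.sqrt(0.8) * random.normal(rng_key)


def gibbs_step(rng_key, y):
    # Dispatch: call the single-chain step directly for one key,
    # otherwise map it over the leading (chain) axis of both arguments.
    if is_single_key(rng_key):
        return gibbs_step_one_chain(rng_key, y)
    return jax.vmap(gibbs_step_one_chain)(rng_key, y)


single = gibbs_step(random.PRNGKey(0), jnp.array(0.5))
batched = gibbs_step(random.split(random.PRNGKey(0), 2), jnp.array([0.5, -0.5]))
print(single.shape, batched.shape)
```

The same function then serves both the one-chain case and the vectorized case, which is the behaviour the proposed init change aims for inside HMCGibbs.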