
HMCGibbs with chain_method="vectorized"

Open WolfgangEnzi opened this issue 1 year ago • 4 comments

I am trying to use HMCGibbs sampling with more than one chain using chain_method="vectorized", but there appears to be a problem with splitting the random keys.

Consider this toy example that I copied from the numpyro documentation, where I only changed the chain_method and the number of chains:

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    y = hmc_sites['y']
    new_x = dist.Normal(0.8 * (1-y), jnp.sqrt(0.8)).sample(rng_key)
    return {'x': new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(
    kernel,
    num_warmup=100,
    num_samples=100,
    num_chains=2,
    chain_method='vectorized',
    progress_bar=False,
)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()

Running the above code gives the following error: TypeError: split accepts a single key, but was given a key array of shape (2,) != (). Use jax.vmap for batching.
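For context, the key-splitting behaviour behind this message can be illustrated outside numpyro. The sketch below uses plain JAX only and assumes the legacy uint32 keys returned by random.PRNGKey; it is not the code path HMCGibbs takes, just a minimal illustration of splitting a batch of per-chain keys with jax.vmap, as the error message suggests.

```python
import jax
from jax import random

# One key per chain, as a vectorized run with two chains would pass in.
chain_keys = random.split(random.PRNGKey(0), 2)

# Calling random.split(chain_keys) directly is what recent JAX versions
# reject with the "split accepts a single key" TypeError quoted above.
# Mapping split over the leading (chain) axis handles each key separately.
split_keys = jax.vmap(random.split)(chain_keys)

# Leading axes are (num_chains, num_splits).
print(split_keys.shape)
```

Each chain ends up with its own pair of fresh keys, which is what a per-chain kernel expects.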

Is there a way to make the vectorize option available for HMCGibbs sampling?

WolfgangEnzi avatar Jan 30 '24 16:01 WolfgangEnzi

Could you change this line to jax.vmap(...) with the default parallel method to see if it works for HMCGibbs?

fehiepsi avatar Jan 30 '24 18:01 fehiepsi

Choosing the "parallel" option and changing that line to jax.vmap did not work for me. The chains still seem to be processed sequentially when I do that.

WolfgangEnzi avatar Feb 13 '24 17:02 WolfgangEnzi

Did you set the host device count to the number of chains? See https://num.pyro.ai/en/stable/utilities.html#set-host-device-count
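As a rough sketch of what that setting involves: the linked utility, numpyro.set_host_device_count(n), works by adding an XLA flag before JAX initializes. To stay JAX-only, the example below sets the same flag by hand; the variable names are illustrative, and on a machine with accelerators the reported device count may differ.

```python
import os

# This has to run before JAX initializes its backends, so it goes at the
# very top of the script, ahead of any JAX computation.
n_chains = 2
flag = f"--xla_force_host_platform_device_count={n_chains}"
os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") + " " + flag).strip()

import jax

# On a CPU-only machine this is expected to report n_chains host devices,
# which is what chain_method="parallel" needs to run chains concurrently.
print(jax.local_device_count())
```

If fewer devices than chains are available, the chains are not run concurrently, which would match the sequential behaviour described above.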

fehiepsi avatar Feb 15 '24 21:02 fehiepsi

I think the HMCGibbs class's init method needs something similar to https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc.py#L782-L790 to detect if it is getting one key or a list of keys and vmap its initialization and sampling functions as needed.

When I was writing a custom Gibbs sampler (that does an HMC step for each conditional rather than drawing from a known distribution), I was able to get vectorized working this way, so it should be doable for this sampler as well.

I imagine it would look a bit like:

# Assumes: from jax import device_put, vmap
#          from numpyro.util import is_prng_key
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs.copy()

    def init_fn(init_params, rng_key):
        ...
        return HMCGibbsState(z, hmc_state, rng_key)

    # A single key means one chain; a batch of keys means vectorized chains.
    if is_prng_key(rng_key):
        init_state = init_fn(init_params, rng_key)
        self._sample_fn = self._sample_one_chain
    else:
        init_state = vmap(init_fn)(init_params, rng_key)
        self._sample_fn = vmap(self._sample_one_chain, in_axes=(0, None, None))
    return device_put(init_state)

Then rename the current sample method to _sample_one_chain and add a new sample method that calls self._sample_fn.

It might need a bit of extra logic around it to work as expected, but I think this is roughly what the solution would look like.
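The dispatch idea can be tried in isolation with a toy kernel. The sketch below is not numpyro code: init_one_chain, sample_one_chain, make_init_and_sample and is_single_key are hypothetical names, the "kernel" is a plain random-walk step, and the key check assumes legacy uint32 keys, where a single key has shape (2,).

```python
import jax
import jax.numpy as jnp
from jax import random


def is_single_key(rng_key):
    # Hypothetical stand-in for a key check: a legacy uint32 key is 1-D,
    # a batch of per-chain keys is 2-D.
    return jnp.ndim(rng_key) == 1


def init_one_chain(z, rng_key):
    # Toy state: current position plus the chain's key.
    return (z, rng_key)


def sample_one_chain(state, scale):
    z, rng_key = state
    rng_key, step_key = random.split(rng_key)
    # Random-walk step standing in for the real Gibbs + HMC update.
    return (z + scale * random.normal(step_key), rng_key)


def make_init_and_sample(init_params, rng_key):
    if is_single_key(rng_key):
        state = init_one_chain(init_params, rng_key)
        sample_fn = sample_one_chain
    else:
        # Map over the chain axis of the state; share `scale` across chains.
        state = jax.vmap(init_one_chain)(init_params, rng_key)
        sample_fn = jax.vmap(sample_one_chain, in_axes=(0, None))
    return state, sample_fn


# Two vectorized chains.
keys = random.split(random.PRNGKey(0), 2)
state, sample_fn = make_init_and_sample(jnp.zeros(2), keys)
state = sample_fn(state, 0.5)

# A single chain goes through the same entry point.
single_state, single_fn = make_init_and_sample(jnp.zeros(()), random.PRNGKey(1))
single_state = single_fn(single_state, 0.5)
```

Choosing the sampling function once at initialization keeps the per-chain logic free of any batching concerns.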

CKrawczyk avatar Feb 23 '24 12:02 CKrawczyk