Distributed lpdf / grad evaluation
I would like to be able to use jax's pjit to parallelize density / gradient evaluation across multiple GPUs. This would allow you to perform standard NUTS/HMC when the data are too large to fit on a single device but multiple devices are available (currently the only option would be a data-subsampling approach, and I believe multiple devices are only usable for embarrassingly parallel tasks such as running multiple chains at once).
Here is a toy example corresponding to the generative model: $\sigma_e,\sigma_\beta\sim\mathrm{priors},$ $\beta \mid \sigma_e,\sigma_\beta \sim N(0, \sigma_\beta),$ $y \mid \beta, \sigma_e,\sigma_\beta \sim N(X\beta, \sigma_e),$ where I'd like to distribute the computation of the matrix-vector product $X\beta$ and the likelihood of $y$. We can generate these data using
import numpy as np
import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
## dataset dimensions
M=100; N=200
## true parameter values
sbeta_true = np.sqrt(.5)
se_true = np.sqrt(.5)
beta_true = np.random.randn(M)*sbeta_true
## observed data
X = np.random.randn(M*N).reshape(N,M)
X = np.apply_along_axis(lambda x: (x-np.mean(x))/np.std(x), 0, X) / np.sqrt(M)
e = np.random.randn(N)*se_true
y = X @ beta_true + e
y -= np.mean(y)
and fit the standard, single-device version using
def lpdf(X, b, y, se):
    return jnp.sum(dist.Normal(0., se).log_prob(y - jnp.dot(X, b)))

def toy_model(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1), sample_shape=(1,))
    beta = numpyro.sample('beta', dist.Normal(0., s_beta), sample_shape=(X.shape[1],))
    lpy = lpdf(X, beta, y, s_e)
    numpyro.factor('y', lpy)

## construct kernel
nuts_kernel = NUTS(toy_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
rng_key = random.PRNGKey(0)
## run model
mcmc.run(rng_key, y=y, X=X)
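For reference, the fit from this single-device run can be inspected with numpyro's built-in summary:

## posterior summary and raw samples from the single-device fit
mcmc.print_summary()
samples = mcmc.get_samples()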
A distributed version of this model can be implemented by sharding $X$ and $y$ across devices and using pjit:
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec
## define log density over the full (global) arrays; pjit partitions the
## data-dependent term and handles the cross-device reduction in the sum
def _lpdf(X, y, b, sb, se):
    ## priors
    ll = dist.HalfCauchy(1).log_prob(sb) + dist.HalfCauchy(1).log_prob(se)
    ## latent variables
    ll += jnp.sum(dist.Normal(0., sb).log_prob(b))
    ## distributed part involving observed data
    ll += jnp.sum(dist.Normal(0., se).log_prob(y - jnp.dot(X, b)))
    return ll
## build a device mesh with data rows split along the 'n' axis
## (assumes the number of rows N is divisible by the number of devices)
mesh = maps.Mesh(np.array(jax.devices()).reshape(-1, 1), ('n', 'm'))
## functions to distribute rows of X (2-D) and y (1-D) across the mesh
shard = pjit(
    lambda x: x,
    in_axis_resources=None,
    out_axis_resources=PartitionSpec('n', 'm'))
shard_vec = pjit(
    lambda x: x,
    in_axis_resources=None,
    out_axis_resources=PartitionSpec('n'))
## distributed likelihood and gradient (gradient w.r.t. b, sb, se);
## y is 1-D, so it is partitioned along 'n' only
pjit_lpdf = pjit(_lpdf,
    in_axis_resources=(PartitionSpec('n', 'm'), PartitionSpec('n'),
                       None, None, None),
    out_axis_resources=None)
grad_lpdf = jax.grad(_lpdf, argnums=(2, 3, 4))
pgrad_lpdf = pjit(grad_lpdf,
    in_axis_resources=(PartitionSpec('n', 'm'), PartitionSpec('n'),
                       None, None, None),
    out_axis_resources=None)
## shard data across devices
with maps.Mesh(mesh.devices, mesh.axis_names):
    X_sharded = shard(X)
    y_sharded = shard_vec(y)
## evaluate multi-gpu lpdf/grad
with maps.Mesh(mesh.devices, mesh.axis_names):
    lp = pjit_lpdf(X_sharded, y_sharded, beta_true, .7, .7)
    lgrad = pgrad_lpdf(X_sharded, y_sharded, beta_true, .7, .7)
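As a quick sanity check (not part of the original example), the distributed evaluation can be compared against a plain single-device call to the same density on the unsharded arrays:

## single-device reference evaluation of the same log density
lp_single = _lpdf(X, y, beta_true, .7, .7)
print(jnp.allclose(lp, lp_single))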
However, there is no way to sample from this lpdf using numpyro's samplers, because numpyro attempts to jit the already-pjit'd density, which throws an error. I'd like to be able to perform HMC in a distributed framework using a pjit'd density. See https://forum.pyro.ai/t/possible-to-use-pmap-within-likelihood-computation/4189 for the original discussion leading to this issue.
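For concreteness, here is a minimal sketch (hypothetical, not from the original post) of the kind of wrapper that hits this error: plugging the pjit'd density into NUTS through potential_fn (ignoring the constrained-to-unconstrained transform the scale parameters would actually need), at which point numpyro's own jit tracing of the potential encounters the nested pjit:

## hypothetical attempt: wrap the pjit'd density as a potential function
def potential_fn(params):
    return -pjit_lpdf(X_sharded, y_sharded, params['beta'], params['s_beta'], params['s_e'])

kernel = NUTS(potential_fn=potential_fn)
mcmc_pjit = MCMC(kernel, num_warmup=500, num_samples=500)
init_params = {'beta': jnp.zeros(M), 's_beta': 1., 's_e': 1.}
with maps.Mesh(mesh.devices, mesh.axis_names):
    ## fails when numpyro jit-compiles the sampler around the pjit'd potential
    mcmc_pjit.run(rng_key, init_params=init_params)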
Just in case it's helpful, there seems to be an experimental implementation of this sort of functionality (MCMC using jax with sharded data) in tensorflow probability here, though it would be great to be able to do this sort of thing within numpyro!
Sorry for the slow response. I'm looking into this issue.
This would be a phenomenal feature to support--it would certainly simplify a lot of necessary calculations for statistical genetics.
This is supported in jax 0.4.4 and newer. See the example in #1514
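For anyone landing on this later: with the newer jax.sharding API the single-device model above runs unchanged on sharded data, since jit now propagates input shardings across devices. A rough sketch (assuming a 1-D mesh over all available devices and that N is divisible by the device count; see #1514 for the actual example):

from jax.sharding import Mesh, NamedSharding, PartitionSpec
## lay the devices out along a single 'n' axis and shard the rows of X and y
mesh = Mesh(np.array(jax.devices()), ('n',))
X_sharded = jax.device_put(X, NamedSharding(mesh, PartitionSpec('n', None)))
y_sharded = jax.device_put(y, NamedSharding(mesh, PartitionSpec('n')))
## the unmodified toy_model from above; the likelihood term is computed across devices
nuts_kernel = NUTS(toy_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
mcmc.run(rng_key, y=y_sharded, X=X_sharded)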