Distributed lpdf / grad evaluation

Open rborder opened this issue 3 years ago • 1 comments

I would like to be able to use jax's pjit to parallelize density / gradient evaluation across multiple GPUs. This would allow you to perform standard NUTS/HMC when the data are too large to fit on a single device but multiple devices are available. Currently the only option is a data subsampling approach, and I believe multiple devices are only usable for embarrassingly parallel tasks such as running multiple chains at once.

Here is a toy example corresponding to the generative model: $\sigma_e,\sigma_\beta\sim\mathrm{priors},$ $\beta \vert \sigma_e,\sigma_\beta \sim N(0, \sigma_\beta),$ $y \vert\beta, \sigma_e,\sigma_\beta \sim N(X\beta, \sigma_e),$ where I'd like to distribute the computation of the matrix vector product $X\beta$ and the likelihood of $y$. We can generate these data using

import numpy as np
import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp

## dataset dimensions
M=100; N=200

## true parameter values
sbeta_true = np.sqrt(.5)
se_true = np.sqrt(.5)
beta_true = np.random.randn(M)*sbeta_true

## observed data
X = np.random.randn(M*N).reshape(N,M)
X = np.apply_along_axis(lambda x: (x-np.mean(x))/np.std(x), 0, X) / np.sqrt(M)
e = np.random.randn(N)*se_true
y = X @ beta_true + e
y -= np.mean(y)

and fit the standard, single device version using

def lpdf(X,b,y,se):
    return jnp.sum(dist.Normal(0., se).log_prob(y-jnp.dot(X,b)))

def toy_model(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1),sample_shape=(1,))
    beta = numpyro.sample('beta', dist.Normal(0.,s_beta),sample_shape=(X.shape[1],))
    lpy=lpdf(X, beta, y, s_e)
    numpyro.factor('y',lpy)
    
## construct kernel
nuts_kernel = NUTS(toy_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)

rng_key = random.PRNGKey(0)

## run model
mcmc.run(rng_key, y=y, X=X) 
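
The reason the likelihood term lends itself to sharding is that it is a sum over rows of $X$ and $y$, so each device could evaluate a partial sum on its own block of rows. Below is a minimal NumPy-only sketch of that decomposition. The helper names (`normal_logpdf`, `full_loglik`, `sharded_loglik`) and the `_demo` arrays are hypothetical and exist only for illustration; no devices are involved, the "shards" are just row blocks.

```python
import numpy as np

def normal_logpdf(resid, scale):
    # Elementwise log density of N(0, scale) evaluated at resid
    return -0.5 * np.log(2 * np.pi) - np.log(scale) - 0.5 * (resid / scale) ** 2

def full_loglik(X, y, beta, se):
    # Log-likelihood of y computed on the whole dataset at once
    return np.sum(normal_logpdf(y - X @ beta, se))

def sharded_loglik(X, y, beta, se, n_shards):
    # Split rows into blocks, compute a partial sum per block, then add them up
    X_blocks = np.array_split(X, n_shards, axis=0)
    y_blocks = np.array_split(y, n_shards)
    partials = [
        np.sum(normal_logpdf(yb - Xb @ beta, se))
        for Xb, yb in zip(X_blocks, y_blocks)
    ]
    return sum(partials)

# Small synthetic problem with the same shapes as the toy example
rng = np.random.default_rng(0)
N_demo, M_demo = 200, 100
X_demo = rng.standard_normal((N_demo, M_demo)) / np.sqrt(M_demo)
beta_demo = rng.standard_normal(M_demo) * np.sqrt(0.5)
y_demo = X_demo @ beta_demo + rng.standard_normal(N_demo) * np.sqrt(0.5)

ll_full = full_loglik(X_demo, y_demo, beta_demo, 0.7)
ll_sharded = sharded_loglik(X_demo, y_demo, beta_demo, 0.7, n_shards=4)
print(np.isclose(ll_full, ll_sharded))
```

Only the scalar partial sums (and, for the gradient, vectors of length M) need to be combined across blocks, which is why the data themselves never have to live on a single device.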

A distributed version of this model can be implemented by sharding $X$ and $y$ across devices and using pjit:

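## imports as they existed in the jax releases current when this issue was opened
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit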

## define log density
def _lpdf(X,y,b,sb,se):
    ## priors
    ll = dist.HalfCauchy(1).log_prob(sb) + dist.HalfCauchy(1).log_prob(se)
    ## latent variables
    ll += jnp.sum(dist.Normal(0., sb).log_prob(b))
    ## distributed part involving observed data
    ll += jax.lax.psum(dist.Normal(0., se).log_prob(y-jnp.dot(X,b)),0)[0]
    return ll

## function to distribute rows of X,y across mesh
shard = pjit(
    lambda x: x,
    in_axis_resources=None,
    out_axis_resources=PartitionSpec('n', 'm'))

## distributed likelihood
pjit_lpdf = pjit(_lpdf,
                 in_axis_resources=(PartitionSpec('n', 'm'), PartitionSpec('n', 'm'),
                                    None, None, None),
                 out_axis_resources=None)
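## gradient of the log density w.r.t. (b, sb, se); the argnums are an assumption
grad_lpdf = jax.grad(_lpdf, argnums=(2, 3, 4))
pgrad_lpdf = pjit(grad_lpdf,
                 in_axis_resources=(PartitionSpec('n', 'm'), PartitionSpec('n', 'm'),
                                    None, None, None),
                 out_axis_resources=None)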
                 
## shard data across devices

with maps.Mesh(mesh.devices, mesh.axis_names):
    X_sharded = shard(X)
    y_sharded = shard(y)

## evaluate multi-gpu lpdf/grad
with maps.Mesh(mesh.devices, mesh.axis_names):
    lp = pjit_lpdf(X_sharded,y_sharded,beta_true,.7,.7)
    lgrad = pgrad_lpdf(X_sharded,y_sharded,beta_true,.7,.7)

However, there's no way to sample from this lpdf using numpyro's samplers, because numpyro will attempt to jit the already pjit'd density, which throws an error. I'd like to be able to perform HMC in a distributed framework using a pjit'd density. See https://forum.pyro.ai/t/possible-to-use-pmap-within-likelihood-computation/4189 for the original discussion leading to this issue.

rborder, Jun 08 '22 20:06

Just in case it's helpful, there seems to be an experimental implementation of this sort of functionality (MCMC using jax with sharded data) in TensorFlow Probability here, though it would be great to be able to do this sort of thing within numpyro!

rborder, Jun 13 '22 18:06

Sorry for the slow response. I'm looking into this issue.

fehiepsi, Nov 29 '22 14:11

This would be a phenomenal feature to support--it would certainly simplify a lot of necessary calculations for statistical genetics.

quattro, Mar 17 '23 18:03

This is supported in jax 0.4.4 and newer. See the example in #1514

fehiepsi, Mar 19 '23 12:03
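
For readers landing here later: the following is a rough sketch of what sharded density / gradient evaluation can look like with the `jax.sharding` API in recent jax releases. It is not the example from #1514, and everything in it (the one-axis mesh named `"n"`, the `loglik` function, the array sizes) is an assumption made for illustration. It places rows of `X` and `y` across whatever devices are visible and lets `jax.jit` handle the partitioned computation; on a machine with one device it simply runs unsharded.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all visible devices; the axis name "n" is arbitrary
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("n",))

# Pick N as a multiple of the device count so rows divide evenly (assumption)
M = 16
N = 32 * devices.size

rng = np.random.default_rng(0)
X_np = (rng.standard_normal((N, M)) / np.sqrt(M)).astype(np.float32)
beta_np = (rng.standard_normal(M) * np.sqrt(0.5)).astype(np.float32)
y_np = (X_np @ beta_np + rng.standard_normal(N) * np.sqrt(0.5)).astype(np.float32)

# Shard rows of X and entries of y along the "n" mesh axis
X_sharded = jax.device_put(X_np, NamedSharding(mesh, PartitionSpec("n", None)))
y_sharded = jax.device_put(y_np, NamedSharding(mesh, PartitionSpec("n")))

@jax.jit
def loglik(b, se, X, y):
    # Gaussian log-likelihood; jit partitions the work according to the input shardings
    resid = y - X @ b
    return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - jnp.log(se) - 0.5 * (resid / se) ** 2)

lp = loglik(beta_np, 0.7, X_sharded, y_sharded)
grad_b = jax.grad(loglik)(beta_np, 0.7, X_sharded, y_sharded)  # gradient w.r.t. b
```

Because the sharding lives on the arrays rather than on a `pjit`-wrapped function, an outer `jit` no longer conflicts with it, which is presumably what makes this usable from a sampler; the sampler side is not exercised in this sketch.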