Slow sampling of NegativeBinomial distribution

Open stefanheyder opened this issue 1 year ago • 0 comments

Sampling from the NegativeBinomial distribution on the JAX substrate is very slow compared to a JAX-only implementation, especially when the total_count parameter is small.

import tensorflow_probability as tfp
tfp.__version__

'0.23.0'

from tensorflow_probability.substrates.jax.distributions import (
    NegativeBinomial as NBinom,
)
from jax import numpy as jnp, random as jrn, config
config.update("jax_enable_x64", True)

N = 1000
mu = 1e4
small_r = 0.1
middle_r = 10
large_r = 1000
key = jrn.PRNGKey(342354234)

nbinom_small_r = NBinom(total_count=small_r, logits=jnp.log(mu) - jnp.log(small_r))
nbinom_middle_r = NBinom(total_count=middle_r, logits=jnp.log(mu) - jnp.log(middle_r))
nbinom_large_r = NBinom(total_count=large_r, logits=jnp.log(mu) - jnp.log(large_r))

%timeit nbinom_small_r.sample(seed=key, sample_shape=(N,)).block_until_ready()
%timeit nbinom_middle_r.sample(seed=key, sample_shape=(N,)).block_until_ready()
%timeit nbinom_large_r.sample(seed=key, sample_shape=(N,)).block_until_ready()

6 s ± 38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
235 ms ± 710 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
119 ms ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

def sample_nbinom(key, r, mu, shp):
    # Gamma-Poisson mixture: draw Gamma(shape=r) rates rescaled to mean mu,
    # then draw Poisson counts with those rates.
    key, sk_g, sk_p = jrn.split(key, 3)
    gamma_sample = mu / r * jrn.gamma(sk_g, r, shp)
    return jrn.poisson(sk_p, gamma_sample)

%timeit sample_nbinom(jrn.PRNGKey(0), small_r, mu, (N,)).block_until_ready()
%timeit sample_nbinom(jrn.PRNGKey(0), middle_r, mu, (N,)).block_until_ready()
%timeit sample_nbinom(jrn.PRNGKey(0), large_r, mu, (N,)).block_until_ready()

747 µs ± 65.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
604 µs ± 38 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
532 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
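
To check that the two samplers target the same distribution, here is a small standard-library sketch of the moments implied by the parametrization logits = log(mu) - log(r). It assumes the convention mean = total_count * exp(logits) and probs = sigmoid(logits), which is my reading of the TFP NegativeBinomial parametrization and is not verified in this issue. `nb_moments` is a hypothetical helper, not part of any library. Under that assumption, both approaches should have mean mu and variance mu + mu**2 / r.

```python
import math


def nb_moments(total_count, logits):
    """Mean and variance of a negative binomial given (total_count, logits).

    Assumed convention (not confirmed in the issue):
    mean = total_count * exp(logits), probs = sigmoid(logits),
    variance = mean / (1 - probs).
    """
    mean = total_count * math.exp(logits)
    # 1 - sigmoid(logits), computed directly to avoid cancellation
    # when probs is very close to 1 (small total_count, large mu).
    one_minus_probs = 1.0 / (1.0 + math.exp(logits))
    variance = mean / one_minus_probs
    return mean, variance


mu = 1e4
for r in (0.1, 10, 1000):
    logits = math.log(mu) - math.log(r)
    mean, variance = nb_moments(r, logits)
    # Gamma-Poisson mixture with Gamma(shape=r, scale=mu / r) rates:
    # E[X] = mu and Var[X] = mu + mu**2 / r.
    print(f"r={r}: mean={mean:.6g}, var={variance:.6g}, "
          f"mixture var={mu + mu**2 / r:.6g}")
```

This only compares the target moments. It says nothing about why the TFP sampler is slower.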

I'm using tfp 0.23.0 under Python 3.10.13 because sampling under 0.24.0 with Python 3.12 does not work for me (I encounter behavior similar to #1838).

stefanheyder, Sep 28 '24 11:09