
pmap over `num_particles` in SVI

Open snehjp2 opened this issue 2 years ago • 1 comments

Hi,

In `Trace_ELBO`, the `num_particles` argument effectively introduces a batch size for estimating the ELBO gradient when `num_particles > 1`. By default, the computation is vectorized over the particles. Is it possible to also distribute this batch dimension over devices (e.g. when running on multiple GPUs)? My particular application is prone to JAX OOM errors and would benefit from distributing the particles with `jax.pmap`.

snehjp2 avatar Sep 22 '23 20:09 snehjp2

If you get OOM errors, you can set `vectorize_particles=False`. You can also use `PositionalSharding`, like in MCMC, I guess.

If you want to pmap over particles, could you make a PR for it? I think we can simply allow a callable `vectorize_particles` and call it here.
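As a rough illustration of what such a callable might look like, here is a pure JAX sketch that splits the particle keys across local devices with `jax.pmap` and vectorizes within each device with `jax.vmap`. The names `pmap_particles` and `single_particle_elbo` are hypothetical, the ELBO function is a stand-in, and nothing here is NumPyro's actual API:

```python
import jax
import jax.numpy as jnp
from jax import random


def single_particle_elbo(rng_key):
    # Hypothetical stand-in for one particle's ELBO estimate.
    z = random.normal(rng_key)
    return -0.5 * z**2


def pmap_particles(fn, rng_keys):
    """Apply fn to every key, sharding the particle axis over local devices."""
    n_dev = jax.local_device_count()
    num_particles = rng_keys.shape[0]
    if num_particles % n_dev != 0:
        raise ValueError("num_particles must be divisible by the device count")
    # Reshape (num_particles, ...) -> (n_dev, particles_per_device, ...).
    keys = rng_keys.reshape((n_dev, num_particles // n_dev) + rng_keys.shape[1:])
    # pmap maps over devices; vmap vectorizes the particles on each device.
    out = jax.pmap(jax.vmap(fn))(keys)
    # Flatten back to one leading particle axis.
    return out.reshape((num_particles,) + out.shape[2:])


num_particles = 8 * jax.local_device_count()
rng_keys = random.split(random.PRNGKey(0), num_particles)
elbos = pmap_particles(single_particle_elbo, rng_keys)
loss = -jnp.mean(elbos)
```

On a single device this reduces to a plain `vmap`, so the result can be checked against `jax.vmap(single_particle_elbo)(rng_keys)` before moving to multiple GPUs.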

fehiepsi avatar Sep 22 '23 22:09 fehiepsi