
pmap seems to drastically improve performance in the example notebook

Open: elanmart opened this issue 2 years ago · 4 comments

Description

I am running blackjax in WSL2 on a 32 core CPU.

I was playing with the example notebook (https://blackjax-devs.github.io/blackjax/examples/Introduction.html#), and noticed that the CPU utilization is actually quite low when running multiple chains.

I modified the code by first running:

import numpyro as npr

# Expose 32 (virtual) CPU devices to JAX; this must run before JAX initializes its backend
npr.util.set_host_device_count(32)
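As far as I know, `set_host_device_count` works by adding the XLA flag `--xla_force_host_platform_device_count` to the `XLA_FLAGS` environment variable, so the same effect can be had without NumPyro. A minimal sketch in plain JAX, assuming only that the flag is set before JAX initializes its CPU backend (i.e. before the first JAX computation or device query). The resulting devices are virtual: they all share the same host CPU.

```python
import os

# Drop any existing device-count flag, then request 32 host (CPU) devices.
# This has to happen before JAX initializes its CPU backend.
flags = [
    f
    for f in os.environ.get("XLA_FLAGS", "").split()
    if not f.startswith("--xla_force_host_platform_device_count")
]
flags.append("--xla_force_host_platform_device_count=32")
os.environ["XLA_FLAGS"] = " ".join(flags)

import jax

# Query the CPU backend explicitly so the check also works on GPU machines.
cpu_devices = jax.devices("cpu")
print(len(cpu_devices))
```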

Then I reused the inference loop from the single-chain example, but parallelized the execution with pmap instead of vmap:

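import jax

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, num_chains)  # one key per chain

# Map over keys and initial states (axis 0); the kernel and the number of
# samples are broadcast to every device as static (hashable) arguments.
# num_chains must not exceed the number of devices (here 32).
pmapped_inference_loop = jax.pmap(
    inference_loop, in_axes=(0, 0, None, None), static_broadcasted_argnums=(2, 3)
)

states = pmapped_inference_loop(keys, initial_states, nuts.step, 1_000)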
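For anyone without the notebook at hand, here is a self-contained sketch of the same vmap/pmap pattern. It does not use BlackJAX: `toy_kernel` is a hypothetical stand-in for `nuts.step` that, like NUTS, runs a random, data-dependent number of inner iterations per transition. The `inference_loop` signature `(rng_key, initial_state, kernel, num_samples)` is an assumption inferred from the `in_axes` in the snippet above, and the 4 virtual CPU devices are an arbitrary choice. Since pmap needs one device per chain, the number of chains is taken from `jax.local_device_count()`.

```python
import os

# Request 4 virtual CPU devices (arbitrary); must be set before JAX starts up.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

from functools import partial

import jax
import jax.numpy as jnp


def toy_kernel(rng_key, state):
    """Hypothetical stand-in for nuts.step: runs between 1 and 99 inner
    iterations (think: leapfrog steps) and adds that count to the state."""
    num_inner = jax.random.randint(rng_key, (), 1, 100)

    def cond(carry):
        i, _ = carry
        return i < num_inner

    def body(carry):
        i, total = carry
        return i + 1, total + 1

    _, new_state = jax.lax.while_loop(cond, body, (jnp.int32(0), state))
    return new_state


def inference_loop(rng_key, initial_state, kernel, num_samples):
    """Single-chain loop: scan the kernel over num_samples fresh keys."""

    def one_step(state, key):
        state = kernel(key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


num_chains = jax.local_device_count()  # pmap: at most one chain per device
num_samples = 200
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
initial_states = jnp.zeros(num_chains, dtype=jnp.int32)

# vmap: all chains batched on one device; kernel and num_samples are closed over.
vmap_loop = jax.vmap(partial(inference_loop, kernel=toy_kernel, num_samples=num_samples))
vmap_states = vmap_loop(keys, initial_states)

# pmap: one chain per device; kernel and num_samples are static broadcast args.
pmap_loop = jax.pmap(
    inference_loop, in_axes=(0, 0, None, None), static_broadcasted_argnums=(2, 3)
)
pmap_states = pmap_loop(keys, initial_states, toy_kernel, num_samples)

print(vmap_states.shape, pmap_states.shape)
```

With the same keys both versions should produce the same samples; only the execution strategy differs. Under vmap, the batched `while_loop` keeps iterating until every chain's condition is false, which is the "waiting for the slowest chain" effect discussed below.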

This seems to cut the running time from about 2 minutes to about 3 seconds:

# vmap
Wall time: 2min 10s
# pmap
Wall time: 2.91 s

Am I doing something wrong here, or should the example actually be adjusted to use pmap?

Reproducing

See full notebooks here:
https://gist.github.com/elanmart/810f1964738b0ddd8f108b17b7969f82

Setup

Python implementation: CPython
Python version       : 3.9.12
IPython version      : 8.4.0

jax     : 0.3.14
jaxlib  : 0.3.14
blackjax: 0.8.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.10.102.1-microsoft-standard-WSL2
Machine     : x86_64
Processor   : x86_64
CPU cores   : 32
Architecture: 64bit

elanmart, Jul 10 '22 16:07

It is expected that pmap is faster than vmap when using NUTS: with vmap, at each sampling step we wait for the chain with the longest leapfrog trajectory. This could explain the low CPU utilization, as most of the time the other chains have already finished their sample and are just waiting for the few chains that need a large number of leapfrog steps. But in this case the speed difference is pretty huge, likely because of the poor performance of an untuned NUTS.

junpenglao, Jul 11 '22 05:07

And yes, I think adding a pmap example is a great idea! I don't think we have many of those currently.

junpenglao, Jul 11 '22 05:07

I agree with what @junpenglao said: I would expect that to happen with NUTS. Each step can only be as fast as the slowest chain, and these delays can add up to quite a lot after a few thousand steps. pmap, by contrast, runs the chains completely independently.
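To make the "slowest chain" argument concrete, here is an idealized cost model that counts only leapfrog steps on the critical path. The per-transition leapfrog counts are made up (2**depth with a random depth between 1 and 10), and the model ignores both the vectorization gains of batching and pmap's own overhead, so treat it as a rough illustration rather than a benchmark.

```python
import jax
import jax.numpy as jnp

num_chains, num_steps = 4, 1_000
key = jax.random.PRNGKey(0)

# Made-up, heavy-tailed leapfrog counts per (chain, transition): 2**depth, depth in [1, 10].
depth = jax.random.randint(key, (num_chains, num_steps), 1, 11)
leapfrogs = 2 ** depth

# vmap: every transition lasts as long as the slowest chain at that step.
vmap_cost = jnp.sum(jnp.max(leapfrogs, axis=0))

# pmap (one device per chain): chains never wait for each other, so the
# total time is set by the chain with the largest overall leapfrog count.
pmap_cost = jnp.max(jnp.sum(leapfrogs, axis=1))

print(int(vmap_cost), int(pmap_cost), float(vmap_cost / pmap_cost))
```

A sum of per-step maxima can never be smaller than the maximum of per-chain sums, so in this model vmap is never faster than pmap, and the gap grows with the variance of the per-step leapfrog counts.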

@elanmart, would you like to add a pmap example at the end of this notebook, with a short explanation of the difference?

rlouf, Jul 11 '22 08:07

Thanks a lot for the explanation! I'll be happy to open a PR adding a small section on pmap.

elanmart, Jul 11 '22 17:07

Apologies, I forgot about this issue. I'll get to it this week and open a PR once #243 is merged.

elanmart, Aug 28 '22 23:08

No problem! Thank you for letting us know.

rlouf, Aug 29 '22 13:08