blackjax
pmap seems to drastically improve performance in the example notebook
Description
I am running blackjax in WSL2 on a 32-core CPU.
I was playing with the example notebook (https://blackjax-devs.github.io/blackjax/examples/Introduction.html#), and noticed that the CPU utilization is actually quite low when running multiple chains.
I have modified the code by first running
import numpyro as npr
npr.util.set_host_device_count(32)
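For reference, numpyro's `set_host_device_count` works by setting the XLA flag `--xla_force_host_platform_device_count`. JAX only reads that flag when it initializes its CPU backend, so the call has to come before any JAX computation. Below is a minimal sketch of doing the same without numpyro. It assumes 8 virtual devices rather than the 32 used here, and `NUM_DEVICES` is a hypothetical name chosen for illustration.

```python
import os
import re

# Assumption: 8 virtual CPU devices (the issue used 32).
NUM_DEVICES = 8

# Drop any existing device-count flag, then prepend ours, keeping other XLA flags.
other_flags = re.sub(
    r"--xla_force_host_platform_device_count=\S+", "", os.environ.get("XLA_FLAGS", "")
)
os.environ["XLA_FLAGS"] = " ".join(
    [f"--xla_force_host_platform_device_count={NUM_DEVICES}"] + other_flags.split()
)

import jax  # imported only after XLA_FLAGS is set, so the flag takes effect

cpu_devices = jax.devices("cpu")
print(len(cpu_devices))
```

If JAX has already run anything in the same process (for example, earlier in a notebook session), the flag is ignored and the kernel has to be restarted.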
Then I re-used the inference loop from the single-chain example, but instead of using vmap I used pmap to parallelize the execution:
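rng_key = jax.random.PRNGKey(0)
# One PRNG key per chain; pmap requires num_chains <= jax.local_device_count()
keys = jax.random.split(rng_key, num_chains)
# Map over the keys and initial states (axis 0); nuts.step and the number of
# samples are broadcast to every device as static, hashable arguments
inference_loop = jax.pmap(
    inference_loop, in_axes=(0, 0, None, None), static_broadcasted_argnums=(2, 3)
)
states = inference_loop(keys, initial_states, nuts.step, 1_000)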
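Here is a self-contained sketch of the same vmap-versus-pmap pattern for readers without the notebook at hand. It is not blackjax's API. NUTS is swapped for a hypothetical random-walk Metropolis kernel (`rw_step`) on a hypothetical 2D standard-normal target (`logdensity`). `inference_loop` is assumed to be a `jax.lax.scan` loop with the argument order used above. The kernel and sample count are bound with `functools.partial`, so only keys and states are mapped.

```python
import functools

import jax
import jax.numpy as jnp


def logdensity(x):
    # Hypothetical target: standard normal in 2D.
    return -0.5 * jnp.sum(x**2)


def rw_step(rng_key, x):
    # Hypothetical stand-in for nuts.step: one random-walk Metropolis step.
    key_proposal, key_accept = jax.random.split(rng_key)
    proposal = x + 0.5 * jax.random.normal(key_proposal, x.shape)
    log_ratio = logdensity(proposal) - logdensity(x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, proposal, x)


def inference_loop(rng_key, initial_state, kernel, num_samples):
    # Single-chain loop: scan the kernel over one key per sample.
    def one_step(state, key):
        state = kernel(key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


# pmap places one chain per device, so use as many chains as there are devices.
num_chains = jax.local_device_count()
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
initial_states = jnp.zeros((num_chains, 2))

# Bind the kernel and number of samples so only keys and states are mapped.
loop = functools.partial(inference_loop, kernel=rw_step, num_samples=1_000)

# vmap: all chains are batched into one program and advance in lockstep.
vmap_samples = jax.jit(jax.vmap(loop))(chain_keys, initial_states)

# pmap: each chain runs on its own device, independently of the others.
pmap_samples = jax.pmap(loop)(chain_keys, initial_states)

print(vmap_samples.shape, pmap_samples.shape)
```

This toy kernel does the same amount of work on every step, so it will not reproduce the gap reported above. It only shows the mechanics. The gap comes from NUTS trajectories having different lengths on different chains.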
This seems to cut the running time from about 2 minutes to about 3 seconds:
# vmap
Wall time: 2min 10s
# pmap
Wall time: 2.91 s
Am I doing something wrong here, or should the example actually be adjusted to use pmap?
Reproducing
See full notebooks here:
https://gist.github.com/elanmart/810f1964738b0ddd8f108b17b7969f82
Setup
Python implementation: CPython
Python version : 3.9.12
IPython version : 8.4.0
jax : 0.3.14
jaxlib : 0.3.14
blackjax: 0.8.2
Compiler : GCC 7.5.0
OS : Linux
Release : 5.10.102.1-microsoft-standard-WSL2
Machine : x86_64
Processor : x86_64
CPU cores : 32
Architecture: 64bit
It is expected that pmap is faster than vmap when using NUTS: with vmap, at each sampling step we are waiting for the chain with the longest leapfrog trajectory. This could also explain the low CPU utilization: most of the time, the other chains have already finished their sample and are just waiting for the few chains that take a large number of leapfrog steps. But in this case the speed difference is pretty huge, likely because of the poor performance of an untuned NUTS.
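The "waiting for the slowest chain" effect can be made concrete with a back-of-the-envelope cost model. This sketch rests on loud assumptions, and none of its numbers come from the notebook:

- Cost is counted in leapfrog steps only.
- A tree of depth `d` is taken to cost `2**d - 1` leapfrog steps.
- Tree depths are drawn uniformly from 1 to 10, a hypothetical stand-in for an untuned sampler.

```python
import numpy as np

# Hypothetical toy model, not measured data: random tree depths per chain and step.
rng = np.random.default_rng(0)
num_chains, num_steps = 32, 1_000
depths = rng.integers(1, 11, size=(num_chains, num_steps))  # depths 1..10
leapfrogs = 2**depths - 1  # assumed leapfrog steps for a tree of depth d

# vmap: chains advance in lockstep, so every step costs as much as the slowest chain.
vmap_cost = leapfrogs.max(axis=0).sum()

# pmap: each chain runs on its own device, so wall time is set by the chain
# with the largest total amount of work.
pmap_cost = leapfrogs.sum(axis=1).max()

print(vmap_cost, pmap_cost, vmap_cost / pmap_cost)
```

The sum of per-step maxima can never be smaller than the largest per-chain total, so under this model the lockstep (vmap) cost is always at least the independent (pmap) one. The gap widens as trajectory lengths become more variable across chains.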
And yes, I think it is a great idea to add a pmap example! I don't think we have many of those currently.
I agree with what @junpenglao said: I would expect that to happen with NUTS. With vmap, each step can only be as fast as the slowest chain, and these delays can add up to quite a lot after a few thousand steps. pmap, by contrast, runs the chains completely independently.
Would you like to add an example with pmap at the end of this notebook, with a short explanation of the difference, @elanmart?
Thanks a lot for the explanation! I'll be happy to open a PR adding a small section with pmap.
Apologies, I forgot about this issue. I'll get to it this week and open a PR once #243 is merged.
No problem! Thank you for letting us know.