numpyro
GMM notebook example: MCMC/NUTS simulation is not reproducible
source: https://num.pyro.ai/en/stable/tutorials/gmm.html#MCMC
numpyro.__version__: 0.12.1
jax.__version__: 0.4.13
--
When running the collapsed NUTS example to explore the full posterior, I obtained results that do not match those presented in the tutorial:
```python
from jax import random
from numpyro.infer import MCMC, NUTS

# `model` and `data` are defined earlier in the GMM tutorial.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=50, num_samples=250)
mcmc.run(random.PRNGKey(2), data)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()
```
Obtained posterior density:
With a longer warmup (num_warmup of 150 or more), however, we get roughly the expected behaviour:
With more samples (~2500), the pattern is better:
To clarify: in my previous attempts I used the same values and parameters as the tutorial. When I ran the code on Google Drive, the results matched those shown in the documentation; when I ran it on my laptop, they differed significantly. I can provide more details if needed. Given the fixed random seed and the simplicity of the example, the difference seems quite substantial.
I would also like to thank all the contributors to this library! I am impressed by and excited about the remarkable work its developers have done.
@fehiepsi could this be due to a newer (different) version of JAX?
Yes, it's likely caused by numerical changes. It makes sense to use a longer warmup (num_warmup) and a larger num_samples. Currently, r_hat and n_eff in the checked-in version are pretty poor.