pyro
MCMC with parallel chains gets stuck in Jupyter notebook on Ubuntu
As reported by @activatedgeek, MCMC with multiprocessing seems to work fine on Mac, and even on Ubuntu when run from the terminal, but it gets stuck on Ubuntu when run in the notebook environment. I am opening this issue so that we can investigate further whether other users commonly face this problem.
Jupyter versions:
jupyter-client==5.2.4
jupyter-console==6.0.0
jupyter-core==4.4.0
jupyterlab==1.2.5
jupyterlab-server==1.0.6
I believe this is related to these issues: https://github.com/pytorch/pytorch/issues/17680 and https://github.com/pytorch/pytorch/issues/20375.
Not related, but here are some observations from running in a terminal, for reference:
- Without the `torch.multiprocessing.set_start_method("spawn")` (or `"forkserver"`) line, the baseball example fails with the default `fork` method due to this issue.
- Under the `spawn` method, the default sharing strategy on Linux, `file_descriptor`, leads to a `bad value(s) in fds_to_keep` error. This does not happen with `forkserver`.
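In a notebook, a cell containing `set_start_method(...)` is often run more than once, and the standard-library `multiprocessing.set_start_method` raises `RuntimeError` if the start method has already been set (unless `force=True` is passed). A minimal sketch of a guard for that, assuming `torch.multiprocessing.set_start_method` behaves the same way since that module wraps the standard one; `ensure_start_method` is a hypothetical helper, not part of Pyro or PyTorch:

```python
import multiprocessing as mp
from typing import Optional


def ensure_start_method(method: str) -> str:
    """Set the global start method only if none has been chosen yet.

    Hypothetical helper. A second mp.set_start_method() call without
    force=True raises RuntimeError("context has already been set"), which
    is easy to trigger by re-running a notebook cell. Here we leave an
    existing choice alone and simply report what is in effect.
    """
    # allow_none=True returns None instead of fixing the platform default.
    current: Optional[str] = mp.get_start_method(allow_none=True)
    if current is None:
        mp.set_start_method(method)
        current = method
    return current


if __name__ == "__main__":
    # Safe to execute repeatedly in the same process.
    print(ensure_start_method("spawn"))
    print(ensure_start_method("spawn"))
```

If the returned method differs from the one requested, something else in the process already fixed the start method, and `force=True` (or a per-call `multiprocessing.get_context(...)`) would be needed instead.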
I think pytorch/pytorch#20375 is the reason this fails: the default on Ubuntu, and with Python 3.8 across all platforms, is the "spawn" context, which causes issues. I can replicate it on Mac too if I set `mp_context="spawn"`. I suppose one solution would be to change the default depending on the notebook environment.
@activatedgeek, in the meantime you could try `mp_context='forkserver'` when running in the notebook to work around this issue. Let us know whether that works or still causes issues.
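Before passing `forkserver` along, it may be worth checking that the current platform supports it at all, because `multiprocessing.get_all_start_methods()` lists only what is available (on Windows it is just `spawn`). A small sketch; `pick_start_method` and its preference order are illustrative assumptions:

```python
import multiprocessing as mp
from typing import Iterable, Optional


def pick_start_method(
    preferred: Iterable[str] = ("forkserver", "spawn"),
) -> Optional[str]:
    """Return the first method in `preferred` that this platform supports.

    Hypothetical helper. Returns None if nothing matches, which a caller
    can treat as "keep the library's default".
    """
    available = set(mp.get_all_start_methods())
    for method in preferred:
        if method in available:
            return method
    return None


if __name__ == "__main__":
    print(pick_start_method())
```

The result could then be passed as `mp_context` when constructing `MCMC` with more than one chain. This only addresses platform availability; whether `forkserver` avoids the notebook hang is a separate question.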
It might be better to add a note/warning to the `MCMC` class so that users know about these gotchas of PyTorch multiprocessing. WDYT? I worry that `forkserver` might not be available on some platforms.
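One possible shape for such a warning is sketched below. Detecting a notebook by checking whether `ipykernel` has been imported is a heuristic (assumption), and both function names and the message text are hypothetical, not Pyro API:

```python
import sys
import warnings


def running_in_notebook() -> bool:
    """Heuristic: Jupyter kernels import ipykernel; plain scripts usually don't."""
    return "ipykernel" in sys.modules


def warn_if_parallel_chains_in_notebook(num_chains: int) -> bool:
    """Warn before launching parallel chains from a notebook.

    Hypothetical helper; returns True if a warning was emitted.
    """
    if num_chains > 1 and running_in_notebook():
        warnings.warn(
            "MCMC with num_chains > 1 uses multiprocessing and may hang "
            "when run from a Jupyter notebook; consider running it as a "
            "script instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False


if __name__ == "__main__":
    print(warn_if_parallel_chains_in_notebook(4))
```

Emitting a warning rather than silently switching start methods keeps the behavior explicit.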
Thanks for the info, @fehiepsi.
@neerajprad, the issue still seems to exist with `mp_context='forkserver'`.
Yes, thanks @fehiepsi for the info and all your great help in general.
I also bumped into this issue (on CentOS 7.9.2009), and I actually do not need the chains to run in parallel.
I need multiple chains to investigate convergence (in this case, computing `r_hat` statistics).
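When only the diagnostic is needed, one workaround is to run the chains one after another (for example, several runs with `num_chains=1`, each with its own seed) and compute R-hat from the collected samples afterwards. The sketch below is the classic (non-split) Gelman-Rubin statistic for a single scalar parameter in pure Python; it stands in for illustration only, and a library implementation would normally be preferable:

```python
import math
from statistics import mean, variance
from typing import Sequence


def gelman_rubin(chains: Sequence[Sequence[float]]) -> float:
    """Classic Gelman-Rubin R-hat for one scalar parameter.

    `chains` holds m >= 2 chains of equal length n >= 2, e.g. samples
    gathered from m separate single-chain runs executed sequentially.
    """
    m = len(chains)
    n = len(chains[0])
    if m < 2 or n < 2 or any(len(c) != n for c in chains):
        raise ValueError("need >= 2 chains of equal length >= 2")
    # W: mean of the within-chain sample variances (divides by n - 1).
    w = mean([variance(c) for c in chains])
    # B / n: sample variance of the chain means across chains.
    b_over_n = variance([mean(c) for c in chains])
    # Pooled estimate of the marginal posterior variance.
    var_plus = (n - 1) / n * w + b_over_n
    return math.sqrt(var_plus / w)


if __name__ == "__main__":
    agreeing = [[1.0, 2.0, 3.0, 4.0], [2.0, 1.0, 4.0, 3.0]]
    disagreeing = [[0.0, 1.0, 0.0, 1.0], [10.0, 11.0, 10.0, 11.0]]
    print(gelman_rubin(agreeing), gelman_rubin(disagreeing))
```

Values close to 1 suggest the chains agree; a common rule of thumb flags values above about 1.1, and stricter guidance uses 1.01. Running chains sequentially gives up the speedup but sidesteps multiprocessing entirely.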
Wondering whether a fix for this ever emerged? `num_chains > 1` still fails when run in a Jupyter notebook.
Hi @AndrewFalkowski, I don't think the issue has been fixed in Pyro. I'm not sure what the current status of the related PyTorch issues mentioned above is.
Ah, bummer. I typically like to build models in Jupyter notebooks before migrating to a proper script, so I guess I'll have to rework my workflow a little. Thanks!