Register the overloads added by CustomDist in worker processes
Description
Currently `sample_smc` can fail with a `NotImplementedError` when it is used with a model defined using `CustomDist`. If a `CustomDist` is used without the `dist` parameter, the overloads for `_logprob`, `_logcdf` and `_support_point` are registered only in the main process.
This PR adds an initializer which registers the overloads in the worker processes of the pool used in sample_smc.
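For context, the idea looks roughly like the sketch below. The names (`_register_overloads`, `registration_callbacks`, `run_chains`) are illustrative rather than the actual identifiers used in this PR: the pool is created with an `initializer` that re-runs the `CustomDist` dispatch registrations inside every worker process.

```python
# Hypothetical sketch of a pool initializer that re-registers CustomDist
# overloads in each worker. `registration_callbacks` stands in for callables
# collected in the main process that perform the _logprob / _logcdf /
# _support_point registrations; in practice they may need to be cloudpickled
# so they can be sent to spawned workers.
import multiprocessing


def _register_overloads(registration_callbacks):
    """Pool initializer: runs once in every worker process."""
    for register in registration_callbacks:
        register()


def run_chains(registration_callbacks, n_workers):
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(
        n_workers,
        initializer=_register_overloads,
        initargs=(registration_callbacks,),
    ) as pool:
        ...  # submit the per-chain sampling jobs as usual
```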
Related Issue
- [x] Closes #7224
- [ ] Related to #
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [x] Included tests that prove the fix is effective or that the new feature works
- [x] Added necessary documentation (docstrings and/or example notebooks)
- [x] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [x] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pymc--7241.org.readthedocs.build/en/7241/
:sparkling_heart: Thanks for opening this pull request! :sparkling_heart: The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.
Hi @EliasRas, I'll need some time to review this one properly. Thanks for taking the initiative
Looks like I messed up by rebasing instead of merging and introduced plenty of unnecessary commits to this feature. Does it need to be fixed?
Yes, that needs to be fixed; it happens to everyone. One approach is to start clean and cherry-pick your commits.
I think it's more complicated than this. The following example has specific dispatch, but no RV that shows up in the graph:
```python
import pymc as pm


def _logp(value, mu):
    return -((value - mu) ** 2)


def _dist(mu, size=None):
    return pm.Normal.dist(mu, 1, size=size)


with pm.Model():
    mu = pm.Normal("mu", 0)
    pm.Potential("term", pm.logp(pm.CustomDist.dist(mu, logp=_logp, dist=_dist), [1, 2]))
    pm.sample_smc(draws=6, cores=1)
```
It also fails even with a single core
> It also fails even with a single core
22e8f0bb4a02d874856438065efa7b3ef2645e13 refactored `sample_smc`, and I think the error should now show up even with a single core, since sampling is always done in another process. Previously that was the case only when `cores>1`, because there were separate `run_chains_parallel` and `run_chains_sequential` functions.
Somehow, in main, I am getting `ConnectionResetError: [Errno 104] Connection reset by peer` even for unrelated models without any sort of CustomDist.
> Somehow, in main, I am getting `ConnectionResetError: [Errno 104] Connection reset by peer` even for unrelated models without any sort of CustomDist.
Okay, it's something about the new progress bar and the PyCharm interactive Python console. If I run it from IPython or a terminal it works. But it also works in main for me?
I cannot reproduce a failure with your test locally (after avoiding the PyCharm issue) nor in a Colab environment: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing
Can you share more details about your environment/setup?
> Can you share more details about your environment/setup?
I added the output of conda list to the "PyMC version information" section of #7224. I'm running the code using VSCode, if that matters. Do you need anything else?
Basically I followed the install instructions and the pull request tutorial when installing. I might have also pip installed a couple of extra packages here and there.
> I added the output of conda list to the "PyMC version information" section of https://github.com/pymc-devs/pymc/issues/7224. I'm running the code using VSCode, if that matters. Do you need anything else?
We should have at least one person reproduce the problem, because I cannot. It may be a VSCode environment issue. Ideally we wouldn't have to change the codebase.
The test does fail without the changes when I run it from miniforge prompt though.
> The test does fail without the changes when I run it from miniforge prompt though.
Not sure what a miniforge prompt is; can we try to reproduce it here on the CI then? Push just the test without the fixes into a new PR and we'll run it to see if we can reproduce.
Is there anything that needs to be done here besides running the tests?
> Is there anything that needs to be done here besides running the tests?
Sorry for the delay, just kicked off tests.
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 92.19%. Comparing base (c8b22df) to head (d2669f2). Report is 2 commits behind head on main.
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #7241      +/-   ##
==========================================
+ Coverage   92.18%   92.19%   +0.01%
==========================================
  Files         103      103
  Lines       17259    17282      +23
==========================================
+ Hits        15910    15933      +23
  Misses       1349     1349
```
| Files | Coverage Δ | |
|---|---|---|
| pymc/smc/sampling.py | 99.34% <100.00%> (+0.11%) | :arrow_up: |
Thanks @EliasRas, I haven't been able to reproduce this yet, but that's just because I'm in the middle of switching workstations and haven't gotten everything set up yet.
Your fix looks fine to me and I understand what you identified as the cause of the issue: the dispatching mechanism isn't registering the logp and other methods for the dynamically created class. I think that this highlights a caveat in PyMC's and PyTensor's design: spawned processes may not have all the dispatch signatures registered in the main process. I imagine that this is mostly a problem on Windows, where multiprocessing can only spawn new processes, whereas Linux-based systems default to forks, which in principle should copy over the memory contents of the main process. I'm not sure what happens under macOS, because I think it cannot use fork multiprocessing either, for some reason.
With this design caveat in hand, I'm not sure if it's better to have a package-level utility function that serves as a sort of book-keeper: something that can handle communicating the extra dispatch registrations needed to ensure that child processes use the correct dispatching functions. I'm curious to know what @ricardoV94 thinks about this. I don't think this PR has to tackle that kind of work, but we can discuss whether it's necessary here, and maybe later open an issue and a separate PR (also maybe in PyTensor, where dispatching is used for transpilation/compilation and maybe at some point for lazy gradients?).
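To make the "book-keeper" idea a bit more concrete, here is a purely hypothetical sketch; none of these names exist in PyMC. The point is just that registrations made through such a helper are recorded in the main process so they can be replayed in a child process, e.g. from a pool initializer.

```python
# Hypothetical book-keeper for dispatch registrations (not PyMC API).
_RECORDED_REGISTRATIONS = []


def register_and_record(dispatcher, cls, fn):
    """Register `fn` for `cls` on `dispatcher` and remember the call."""
    dispatcher.register(cls, fn)
    _RECORDED_REGISTRATIONS.append((dispatcher, cls, fn))


def replay_registrations(recorded):
    """Intended to run inside a worker process (e.g. as a Pool initializer)."""
    for dispatcher, cls, fn in recorded:
        dispatcher.register(cls, fn)
```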
I guess the underlying reason for the failure is that pickling of `DensityDist` doesn't work out of the box? It sounds like, for some reason, the dispatch functions don't get registered when the object is unpickled. But wouldn't it be cleaner to overwrite the pickling behavior of this class then? We could override the `__getstate__` and `__setstate__` methods to that effect.
> I guess the underlying reason for the failure is that pickling of `DensityDist` doesn't work out of the box?
I don't think the problem is about pickling. The `DensityDist` ends up returning an op that can be cloudpickled. If I recall correctly, it can't be pickled with the standard pickle because the op class is created on the fly. In the process of creating the op, the dispatchers get populated with the callables that are supplied as inputs to the distribution class. As far as I understand, those functions are detached from the RV op, and that's why they never get registered in a spawned process.
I don't mean that the pickling itself throws an error (it doesn't), but that it would be the responsibility of the `DensityDist` object to ensure that the set-up it needs (i.e. registering the logp) is done when it is unpickled.
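As a toy sketch of that idea (the class and registry below are made up and are not PyMC internals): the op keeps a reference to the user-supplied callable and re-runs the registration when it is unpickled.

```python
# Toy sketch of "re-register on unpickle"; the registry and class names are
# hypothetical and do not correspond to PyMC's actual dispatch machinery.
_LOGP_REGISTRY = {}  # stands in for the real dispatch registry


class MyCustomOp:
    def __init__(self, logp_fn):
        self.logp_fn = logp_fn
        self._register()

    def _register(self):
        # In PyMC this would be the _logprob/_logcdf/_support_point registration.
        _LOGP_REGISTRY[type(self)] = self.logp_fn

    def __getstate__(self):
        return {"logp_fn": self.logp_fn}

    def __setstate__(self, state):
        self.logp_fn = state["logp_fn"]
        # Re-run the registration so the unpickling process also knows
        # how to compute the logp for this op.
        self._register()
```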
For instance the following fails with the NotImplementedError, and has nothing to do with smc, so I guess the solution shouldn't be specific to smc?
```python
import pymc as pm
import cloudpickle
import multiprocessing


def use_logp_func(pickled_model):
    model = cloudpickle.loads(pickled_model)
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    print(func(1.0))


if __name__ == "__main__":
    with pm.Model() as model:

        def logp(value):
            return -(value**2)

        pm.DensityDist("x", logp=logp)

    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)

    pickled_model = cloudpickle.dumps(model)
    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(target=use_logp_func, args=(pickled_model,))
    process.start()
    process.join()
```
I completely agree that this problem isn’t unique to smc and is a design caveat that needs to be addressed more comprehensively. I think that we can kind of patch some things:
- Make `Model` objects' `__setstate__` and `__getstate__` repopulate the dispatch registries.
- Get `CustomDist` RV ops to have these methods defined somehow (maybe closures) that repopulate the dispatch registries.
I’m not sure if these two methods can cover all use patterns though.
Alternatively, we could pass the functions needed to each process, which is more like what `pm.sample` does.
This would also avoid recompiling the same functions multiple times?
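A rough, untested sketch of that idea, reusing the toy model from the example above (this is only an illustration, not what `pm.sample` actually does internally): compile the logp once in the main process, ship the compiled function to the worker, and never touch the dispatch registries there.

```python
# Illustrative sketch: compile once, then pass the compiled function to workers.
import pymc as pm
import cloudpickle
import multiprocessing


def worker(pickled_logp_fn, value):
    # The worker only unpickles an already-compiled function and evaluates it;
    # it never needs the CustomDist dispatch registrations.
    logp_fn = cloudpickle.loads(pickled_logp_fn)
    print(logp_fn(value))


if __name__ == "__main__":
    with pm.Model() as model:

        def logp(value):
            return -(value**2)

        pm.DensityDist("x", logp=logp)

    logp_fn = pm.pytensorf.compile_pymc(model.value_vars, model.logp())
    pickled_logp_fn = cloudpickle.dumps(logp_fn)

    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=worker, args=(pickled_logp_fn, 1.0))
    p.start()
    p.join()
```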
@lucianopaz
Point 1 is pretty straightforward, but could you explain what you meant by point 2? How would it be different from overriding `__getstate__` and `__setstate__`?
I started working on `__getstate__` and `__setstate__`, but I realized that I can't just copy the current implementation due to circular imports. Would either of these be a good idea?
- Create a `utils` module (or something similar) in `pymc.model` that handles the custom methods and registering them.
- Add a method for RVs that registers the custom methods if necessary. `CustomDistRV` and `CustomSymbolicDistRV` would have the registrations, and `RandomVariable` and `SymbolicRandomVariable` would just pass.
I think we should explore an alternative where we compile the functions SMC needs and fork afterwards like pm.sample does.
This approach seems more brittle?
It would also avoid re-compiling the same functions in each chain
I agree on the re-compiling part but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows even if it is a bad way.
> I agree on the re-compiling part but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows even if it is a bad way.
I think this limitation is likely deeper than what you're addressing here. As @aseyboldt and @lucianopaz mentioned, dynamic dispatching is a recurring theme in our codebase and in PyTensor's.
However, I don't agree with their solutions:
Using the class that's being dispatched to register the dispatches during pickling seems at odds with the point of dispatching. The class shouldn't have to know what's being dispatched upon.
For instance, we also have `icdf` methods: what if someone dispatched on it from the outside, does pickling work for it? Or would the `__setstate__`/`__getstate__` need to know about `icdf` (as well as any other dispatch that may not even be part of PyMC)?
It's also not a PyMC model's responsibility. A `CustomDist` can be defined just fine outside of a model.
Thank you for taking the time to explain. I'll start working on the compilation approach.
However, since this fixes existing behavior, I think we can go ahead and merge it as a temporary patch?