
Register the overloads added by CustomDist in worker processes

Open EliasRas opened this issue 1 year ago • 32 comments

Description

Currently sample_smc can fail with a NotImplementedError if it's used with a model defined using CustomDist. If a CustomDist is used without the dist parameter, the overloads for _logprob, _logcdf and _support_point are registered only in the main process.

This PR adds an initializer which registers the overloads in the worker processes of the pool used in sample_smc.
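The initializer idea can be illustrated with a minimal, process-free sketch. It assumes a functools.singledispatch registry similar in spirit to the one behind _logprob. The names _logprob_demo, make_custom_op and register_custom_ops are hypothetical and are not PyMC's API. A spawned worker has only the default implementation registered, so the initializer replays the registrations before any task runs.

```python
from functools import singledispatch


@singledispatch
def _logprob_demo(op, value):
    # Default implementation: what a spawned worker falls back to when the
    # registration only happened in the parent process.
    raise NotImplementedError(f"Logprob method not implemented for {op}")


def make_custom_op(name, logp):
    """Create an Op-like class on the fly and register its logp (hypothetical)."""
    op_class = type(name, (), {})
    _logprob_demo.register(op_class, lambda op, value: logp(value))
    return op_class


def register_custom_ops(specs):
    """Pool initializer: replay the registrations inside a worker (hypothetical).

    `specs` maps each dynamically created class to its user-supplied logp.
    """
    for op_class, logp in specs.items():
        # Bind logp as a default argument so each lambda keeps its own function.
        _logprob_demo.register(op_class, lambda op, value, logp=logp: logp(value))


def quadratic_logp(value):
    return -(value**2)


# Parent process: creating the class registers its logp as a side effect.
ParentOp = make_custom_op("ParentOp", quadratic_logp)
parent_result = _logprob_demo(ParentOp(), 2.0)

# Worker process: the class is available, but the side effect of
# make_custom_op never ran there. Simulated here with a fresh class.
WorkerOp = type("WorkerOp", (), {})
failed_before_init = False
try:
    _logprob_demo(WorkerOp(), 2.0)
except NotImplementedError:
    failed_before_init = True

register_custom_ops({WorkerOp: quadratic_logp})  # what the initializer would do
worker_result = _logprob_demo(WorkerOp(), 2.0)
```

In a real pool the hook would be passed as ctx.Pool(processes, initializer=register_custom_ops, initargs=(specs,)). Here initializer and initargs are standard multiprocessing.Pool parameters. How specs gets serialized for the workers is not shown.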

Related Issue

  • [x] Closes #7224
  • [ ] Related to #

Checklist

Type of change

  • [ ] New feature / enhancement
  • [x] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pymc--7241.org.readthedocs.build/en/7241/

EliasRas, Apr 06 '24 16:04

:sparkling_heart: Thanks for opening this pull request! :sparkling_heart: The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

welcome[bot], Apr 06 '24 16:04

Hi @EliasRas, I'll need some time to review this one properly. Thanks for taking the initiative.

ricardoV94, Apr 09 '24 08:04

Looks like I messed up by rebasing instead of merging and introduced plenty of unnecessary commits to this feature. Does it need to be fixed?

EliasRas, May 26 '24 04:05

Yes, that needs to be fixed, happens to everyone. One approach is to start clean and cherry-pick your commits.

twiecki, May 26 '24 11:05

I think it's more complicated than this. The following example has a CustomDist-specific logp dispatch but no RV that shows up in the graph:

import pymc as pm

def _logp(value, mu):
    return -((value - mu) ** 2)

def _dist(mu, size=None):
    return pm.Normal.dist(mu, 1, size=size)

with pm.Model():
    mu = pm.Normal("mu", 0)
    pm.Potential("term", pm.logp(pm.CustomDist.dist(mu, logp=_logp, dist=_dist), [1, 2]))
    pm.sample_smc(draws=6, cores=1)        

It also fails even with a single core.

ricardoV94, May 27 '24 10:05

> It also fails even with a single core

Commit 22e8f0bb4a02d874856438065efa7b3ef2645e13 refactored sample_smc, and I think the error should now appear even with a single core, since sampling is always done in another process. Previously this was the case only when cores>1, because there were separate run_chains_parallel and run_chains_sequential code paths.

EliasRas, May 27 '24 12:05

Somehow, in main, I am getting ConnectionResetError: [Errno 104] Connection reset by peer even for unrelated models without any sort of CustomDist

ricardoV94, May 27 '24 13:05

> Somehow, in main, I am getting ConnectionResetError: [Errno 104] Connection reset by peer even for unrelated models without any sort of CustomDist

Okay, it's something about the new progress bar and the PyCharm interactive Python console. If I run it from IPython or a terminal, it works. But it also works in main for me?

ricardoV94, May 27 '24 13:05

I cannot reproduce a failure with your test, either locally (after avoiding the PyCharm issue) or in a Colab environment: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing

Can you share more details about your environment/setup?

ricardoV94, May 27 '24 13:05

> Can you share more details about your environment/setup?

I added the output of conda list to "PyMC version information" section of #7224. I'm running the code using VSCode if that matters. Do you need anything else?

Basically, I followed the install instructions and the pull request tutorial when installing. I might also have pip-installed a couple of extra packages here and there.

EliasRas, May 28 '24 07:05

> I added the output of conda list to "PyMC version information" section of https://github.com/pymc-devs/pymc/issues/7224. I'm running the code using VSCode if that matters. Do you need anything else?

We should have at least one person reproduce the problem, because I cannot. It may be a VSCode environment issue. Ideally we wouldn't have to change the codebase.

ricardoV94, May 28 '24 08:05

The test does fail without the changes when I run it from miniforge prompt though.

EliasRas, May 28 '24 08:05

> The test does fail without the changes when I run it from miniforge prompt though.

Not sure what miniforge prompt is. Can we try to reproduce it here on the CI then? Push just the test without the fixes into a new PR and we'll run it to see if we can reproduce it.

ricardoV94, May 28 '24 08:05

Is there anything that needs to be done here besides running the tests?

EliasRas, Jun 11 '24 07:06

> Is there anything that needs to be done here besides running the tests?

Sorry for the delay, just kicked off tests.

twiecki, Jun 11 '24 11:06

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.19%. Comparing base (c8b22df) to head (d2669f2). Report is 2 commits behind head on main.


@@            Coverage Diff             @@
##             main    #7241      +/-   ##
==========================================
+ Coverage   92.18%   92.19%   +0.01%     
==========================================
  Files         103      103              
  Lines       17259    17282      +23     
==========================================
+ Hits        15910    15933      +23     
  Misses       1349     1349              
Files Coverage Δ
pymc/smc/sampling.py 99.34% <100.00%> (+0.11%) :arrow_up:

codecov[bot], Jun 11 '24 11:06

Thanks @EliasRas, I haven't been able to reproduce this yet, but that's just because I'm in the middle of switching workstations and haven't gotten everything set up yet. Your fix looks fine to me, and I understand what you identified as the cause of the issue: the dispatching mechanism isn't registering the logp and other methods for the dynamically created class in the worker processes.

I think this highlights a caveat in PyMC's and PyTensor's design: spawned processes may not have all the dispatch signatures that were registered in the main process. I imagine this is mostly a problem on Windows, where multiprocessing can only spawn new processes, whereas Linux-based systems default to fork, which in principle copies over the memory contents of the main process. I'm not sure what will happen under macOS, because I think it cannot use fork multiprocessing for some reason either.

With this design caveat in mind, I'm not sure whether it would be better to have a package-level utility that serves as a sort of book-keeper: something that can communicate the extra dispatch registrations needed to ensure that child processes use the correct dispatching functions. I'm curious to know what @ricardoV94 thinks about this. I don't think this PR should have to tackle that kind of work, but we can discuss here whether it's necessary, and maybe later open an issue and a separate PR (maybe also in pytensor, where dispatching is used for transpilation/compilation and maybe at some point for lazy gradients?).
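Here is a minimal sketch of the package-level book-keeper idea. DispatchBook and make_logprob_dispatcher are hypothetical names, and a fresh singledispatch function stands in for what a spawned child gets after re-importing the package. The sketch only records and replays registrations. Getting the book itself into the child (for instance serialized with cloudpickle and passed to a pool initializer) is left out.

```python
from functools import singledispatch


def make_logprob_dispatcher():
    """Stand-in for a module-level dispatcher right after the package is imported."""

    @singledispatch
    def logprob(op, value):
        raise NotImplementedError(f"Logprob method not implemented for {op}")

    return logprob


class DispatchBook:
    """Hypothetical package-level record of dynamic dispatch registrations."""

    def __init__(self):
        self.entries = []  # (dispatcher name, class, implementation)

    def register(self, dispatchers, name, cls, func):
        dispatchers[name].register(cls, func)
        self.entries.append((name, cls, func))

    def replay(self, dispatchers):
        """Re-apply every recorded registration, e.g. from a pool initializer."""
        for name, cls, func in self.entries:
            dispatchers[name].register(cls, func)


book = DispatchBook()
CustomOp = type("CustomOp", (), {})

# Parent process: register through the book so the registration is recorded.
parent = {"logprob": make_logprob_dispatcher()}
book.register(parent, "logprob", CustomOp, lambda op, value: -(value**2))

# A spawned child re-imports the package and starts with fresh dispatchers.
child = {"logprob": make_logprob_dispatcher()}
child_failed_before_replay = False
try:
    child["logprob"](CustomOp(), 3.0)
except NotImplementedError:
    child_failed_before_replay = True

book.replay(child)
child_result = child["logprob"](CustomOp(), 3.0)
```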

lucianopaz, Jun 12 '24 09:06

I guess the underlying reason for the failure is that pickling of DensityDist doesn't work out of the box? Sounds like for some reason the dispatch functions don't get registered when the object is unpickled. But wouldn't it be cleaner to overwrite the pickling behavior of this class then? We could override __getstate__ and __setstate__ methods to that effect?
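Here is a minimal sketch of the __getstate__/__setstate__ suggestion. A plain dict replaces the real dispatchers so it can be cleared to mimic a freshly spawned process. CustomDistOp, LOGP_REGISTRY and logprob are hypothetical names. copy.deepcopy is used because it drives the same pickling hooks in-process, without needing the class to be importable by name.

```python
import copy

# A plain dict stands in for a dispatch registry; clearing it mimics the
# empty registry of a freshly spawned worker process.
LOGP_REGISTRY = {}


def logprob(op, value):
    impl = LOGP_REGISTRY.get(type(op))
    if impl is None:
        raise NotImplementedError(f"Logprob method not implemented for {op}")
    return impl(value)


class CustomDistOp:
    """Hypothetical Op that restores its own registration when unpickled."""

    def __init__(self, logp):
        self.logp = logp
        LOGP_REGISTRY[type(self)] = logp

    def __getstate__(self):
        return {"logp": self.logp}

    def __setstate__(self, state):
        self.logp = state["logp"]
        # Unpickling calls this, so the registration is redone on the receiving side.
        LOGP_REGISTRY[type(self)] = self.logp


# One class per distribution, mirroring the dynamically created CustomDist classes.
QuadraticOp = type("QuadraticOp", (CustomDistOp,), {})
op = QuadraticOp(lambda value: -(value**2))

LOGP_REGISTRY.clear()  # pretend we are now in a spawned worker
restored = copy.deepcopy(op)  # goes through __getstate__/__setstate__
result = logprob(restored, 2.0)
```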

aseyboldt, Jun 12 '24 12:06

> I guess the underlying reason for the failure is that pickling of DensityDist doesn't work out of the box?

I don't think the problem is about pickling. DensityDist ends up returning an Op that can be cloudpickled. If I recall correctly, it can't be pickled with the standard pickle module because the Op class is created on the fly. In the process of creating the Op, the dispatchers get populated with the callables that are supplied as inputs to the distribution class. As far as I understand, those functions are detached from the RV Op, and that's why they never get populated in a spawned process.
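This sketch illustrates why a class created on the fly defeats the standard pickle module, which stores classes by reference (module plus qualified name) and so cannot find such a class again. make_op_class and LOGP_REGISTRY are hypothetical names. cloudpickle serializes classes like this by value instead. However, a registration stored in a module-level registry of an importable package is not part of that payload, which matches the point above.

```python
import pickle

LOGP_REGISTRY = {}


def make_op_class(name, logp):
    """Create an Op class on the fly and register its logp (hypothetical).

    The class is only reachable through this return value; there is no
    module-level attribute that pickle could look it up by.
    """
    op_class = type(name, (), {})
    LOGP_REGISTRY[op_class] = logp  # side effect that stays in this process
    return op_class


op_class = make_op_class("DynamicLogpOp", lambda value: -(value**2))

try:
    pickle.dumps(op_class())
    picklable_by_reference = True
except (pickle.PicklingError, AttributeError):
    picklable_by_reference = False
```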

lucianopaz, Jun 12 '24 15:06

I don't mean that the pickling itself throws an error (it doesn't), but that it would be the responsibility of the DensityDist object to ensure that the set-up it needs (ie registering the logp) is done when it is unpickled.

For instance, the following fails with the NotImplementedError and has nothing to do with smc, so I guess the solution shouldn't be specific to smc?

import pymc as pm
import cloudpickle
import multiprocessing


def use_logp_func(pickled_model):
    model = cloudpickle.loads(pickled_model)
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    print(func(1.0))


if __name__ == "__main__":
    with pm.Model() as model:

        def logp(value):
            return -(value**2)

        pm.DensityDist("x", logp=logp)

    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    pickled_model = cloudpickle.dumps(model)

    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(target=use_logp_func, args=(pickled_model,))
    process.start()
    process.join()

aseyboldt, Jun 12 '24 15:06

I completely agree that this problem isn't unique to smc and is a design caveat that needs to be addressed more comprehensively. I think we can patch some things:

  1. Make the __getstate__ and __setstate__ methods of Model objects repopulate the dispatch registries.
  2. Give CustomDist RV Ops these methods somehow (maybe via closures), so that they repopulate the dispatch registries.

I'm not sure if these two methods can cover all use patterns, though.
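One possible reading of option 2 is sketched below; the intended design may differ. The Op class generated on the fly gets a __reduce__ that closes over the user's logp and points back at the factory, so deserializing an instance re-runs the factory and repeats the registration. make_custom_op, LOGP_REGISTRY and logprob are hypothetical, and copy.deepcopy drives __reduce__ the same way pickle would.

```python
import copy

LOGP_REGISTRY = {}


def logprob(op, value):
    impl = LOGP_REGISTRY.get(type(op))
    if impl is None:
        raise NotImplementedError(f"Logprob method not implemented for {op}")
    return impl(value)


def make_custom_op(name, logp):
    """Build an Op class on the fly, register its logp and return an instance.

    The generated __reduce__ closes over `name` and `logp` and points back at
    this factory, so unpickling re-runs it and repeats the registration.
    """

    def __reduce__(self):
        return (make_custom_op, (name, logp))

    op_class = type(name, (), {"__reduce__": __reduce__})
    LOGP_REGISTRY[op_class] = logp
    return op_class()


op = make_custom_op("QuadraticOp", lambda value: -(value**2))
LOGP_REGISTRY.clear()  # mimic the empty registry of a spawned worker
restored = copy.deepcopy(op)  # calls __reduce__, i.e. re-runs the factory
result = logprob(restored, 3.0)
```

With real pickle or cloudpickle, make_custom_op would be stored by reference, so it has to live at module level of an importable module. The user's logp must still be serializable, which is why cloudpickle would be needed.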

lucianopaz, Jun 12 '24 19:06

Alternatively, we could pass the functions needed to each process, which is more like what pm.sample does.

This also avoids recompiling the same functions multiple times?

ricardoV94, Jun 12 '24 21:06

@lucianopaz Point 1. is pretty straightforward but could you explain what you meant by 2.? How would it be different from overriding __getstate__ and __setstate__?

EliasRas, Jun 24 '24 08:06

I started work on __getstate__ and __setstate__ but I realized that I can't just copy the current implementation due to circular imports. Would either of these sound like a good idea?

  1. Create utils module (or something similar) in pymc.model that handles the custom methods and registering them.
  2. Add a method for RVs that registers the custom methods if necessary. CustomDistRV and CustomSymbolicDistRV would have the registrations and RandomVariable and SymbolicRandomVariable would just pass.
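Option 2 in the list above can be sketched as follows. Every RV class exposes a registration hook, ordinary RVs make it a no-op, and the model only has to call it on each RV after unpickling. Because the registration code lives with the RV classes, the model side may avoid importing the registration helpers, although that is an assumption about the circular-import problem. RandomVariableSketch, CustomDistRVSketch, register_dispatches and restore_dispatches are hypothetical names, and a plain dict stands in for the dispatchers.

```python
# Plain dict standing in for a dispatch registry.
LOGP_REGISTRY = {}


class RandomVariableSketch:
    """Hypothetical base class: ordinary RVs have nothing to re-register."""

    def register_dispatches(self):
        pass


class CustomDistRVSketch(RandomVariableSketch):
    """Hypothetical CustomDist RV that knows which logp it needs."""

    def __init__(self, logp):
        self.logp = logp

    def register_dispatches(self):
        LOGP_REGISTRY[type(self)] = self.logp


def restore_dispatches(rvs):
    """What a Model.__setstate__ could call after unpickling (hypothetical).

    The model only calls a method on each RV, so it never has to import the
    registration helpers itself.
    """
    for rv in rvs:
        rv.register_dispatches()


rvs = [RandomVariableSketch(), CustomDistRVSketch(lambda value: -(value**2))]
restore_dispatches(rvs)
```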

EliasRas, Jul 17 '24 07:07

I think we should explore an alternative where we compile the functions SMC needs and fork afterwards like pm.sample does.

This approach seems more brittle?

It would also avoid re-compiling the same functions in each chain
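Here is a minimal sketch of why compiling first sidesteps the problem: dispatch is resolved once, in the parent, and the compiled function no longer consults the registry. compile_logp and lookup_logp are hypothetical stand-ins for PyTensor compilation, and clearing the dict mimics a worker's empty registry. How the compiled function is shipped to the workers (fork, or serialization as pm.sample does it) is not shown.

```python
# Plain dict standing in for a dispatch registry.
LOGP_REGISTRY = {}


def lookup_logp(op):
    """Dispatch step: find the logp implementation registered for op's class."""
    impl = LOGP_REGISTRY.get(type(op))
    if impl is None:
        raise NotImplementedError(f"Logprob method not implemented for {op}")
    return impl


def compile_logp(op):
    """Stand-in for compiling a logp function: dispatch is resolved here, once."""
    impl = lookup_logp(op)

    def compiled(value):
        return impl(value)  # no registry lookup at call time

    return compiled


CustomOp = type("CustomOp", (), {})
LOGP_REGISTRY[CustomOp] = lambda value: -(value**2)
op = CustomOp()

compiled_logp = compile_logp(op)  # parent process: registry is populated
LOGP_REGISTRY.clear()  # workers would start with an empty registry
worker_result = compiled_logp(2.0)  # still works, nothing left to dispatch

recompile_failed = False
try:
    compile_logp(op)  # compiling again inside the worker would fail
except NotImplementedError:
    recompile_failed = True
```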

ricardoV94, Jul 17 '24 07:07

I agree on the re-compiling part but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows even if it is a bad way.

EliasRas, Jul 17 '24 07:07

> I agree on the re-compiling part but shouldn't this still be fixed? It feels like an arbitrary decision to "disallow" using multiprocessing this way only on Windows even if it is a bad way.

I think this limitation is likely deeper than what you're addressing here. As @aseyboldt and @lucianopaz mentioned, dynamic dispatching is a recurring theme in our codebase and in pytensor's.

However, I don't agree with their solutions.

ricardoV94, Jul 17 '24 07:07

Using the class that's being dispatched to register the dispatches during pickling seems at odds with the point of dispatching. The class shouldn't have to know what's being dispatched upon.

For instance, we also have icdf methods. What if someone dispatched on them from the outside: does pickling work for them? Or would setstate/getstate need to know about icdf (as well as any other dispatch that may not even be part of PyMC)?
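The concern about outside dispatches can be made concrete. If an Op's pickling hooks restore only the registrations the class knows about, an icdf registered for the same class by someone else is lost in the worker. CustomDistOp, LOGP_REGISTRY and ICDF_REGISTRY are hypothetical, plain dicts stand in for two independent dispatchers, and copy.deepcopy drives the pickling hooks in-process.

```python
import copy

# Plain dicts standing in for two independent dispatch registries.
LOGP_REGISTRY = {}
ICDF_REGISTRY = {}


class CustomDistOp:
    """Hypothetical Op whose pickling hooks only know about logp."""

    def __init__(self, logp):
        self.logp = logp
        LOGP_REGISTRY[type(self)] = logp

    def __getstate__(self):
        return {"logp": self.logp}

    def __setstate__(self, state):
        self.logp = state["logp"]
        LOGP_REGISTRY[type(self)] = self.logp


op = CustomDistOp(lambda value: -(value**2))
# Placeholder icdf registered for the same class from outside the class.
ICDF_REGISTRY[CustomDistOp] = lambda q: q

# A spawned worker starts with both registries empty.
LOGP_REGISTRY.clear()
ICDF_REGISTRY.clear()
restored = copy.deepcopy(op)

logp_restored = type(restored) in LOGP_REGISTRY  # restored by __setstate__
icdf_restored = type(restored) in ICDF_REGISTRY  # never restored
```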

It's also not a PyMC model responsibility. CustomDist can be defined just fine outside of a model.

ricardoV94, Jul 17 '24 07:07

Thank you for taking the time to explain. I'll start working on the compilation approach.

EliasRas, Jul 17 '24 07:07

However since this fixes existing behavior I think we can go ahead and merge it as a temporary patch?

ricardoV94, Jul 17 '24 07:07