
[feature request] Parallelism support for sequential plate/guide-side enumeration

Open · amifalk opened this issue 1 year ago • 3 comments

For mixture models with arbitrary distributions over each feature, sampling currently must be done serially, even though these operations are trivially parallelizable.

To sample priors from a hierarchical mixture model with one continuous and one binary feature, you would need to do something like

import pyro
import pyro.distributions as dist

n_components = 3  # e.g. three mixture components

with pyro.plate('components', n_components):
   # 'features' is a sequential plate, so each feature is sampled in turn
   for i in pyro.plate('features', 2):
      if i == 0:
         pyro.sample('mu', dist.Normal(0, 1))
         pyro.sample('sigma_sq', dist.InverseGamma(1, 1))
      elif i == 1:
         pyro.sample('theta', dist.Beta(0.1, 0.1))

For mixture models with a large number of features, this serial loop becomes very slow.

I would love to be able to use a Joblib-like syntax for loops like these, e.g.

from joblib import Parallel, delayed

features = [[('mu', dist.Normal(0, 1)), ('sigma_sq', dist.InverseGamma(1, 1))],
            [('theta', dist.Beta(0.1, 0.1))]]

with pyro.plate('components', n_components):
   Parallel(n_jobs=-1)(delayed(sample_priors)(features[i])
                       for i in pyro.plate('features', 2))
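where sample_priors might be a hypothetical helper along these lines:

def sample_priors(feature):
   # sample every prior site belonging to this feature
   for name, d in feature:
      pyro.sample(name, d)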

I have tried something like this, but the Joblib backend and Pyro don't play nicely together: the model doesn't converge. (Presumably pyro.sample statements executed in Joblib worker processes never register with the parent process's effect-handler stack, so their contributions are silently dropped from the trace.)

In a similar vein, adding parallelism for sequential guide-side enumeration could also enable dramatic speedups. For example, when fitting CrossCat with SVI and two truncated stick-breaking processes over views and clusters (my personal use case), enumerating out the view assignments in the model is not possible, and enumerating them out in the guide is much too slow when the enumerated runs can't proceed simultaneously over multiple cores. Since each enumerated model run doesn't share information with the others, this seems like it should be possible in theory.
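For concreteness, a minimal sketch of the kind of guide-side sequential enumeration described above; the number of views and the parameter names are illustrative:

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def guide():
   probs = pyro.param('view_probs', torch.ones(10) / 10,
                      constraint=constraints.simplex)
   # 'sequential' enumeration makes TraceEnum_ELBO replay the model once per
   # possible value, one after another; the replays are independent, so in
   # principle they could run on separate cores.
   pyro.sample('view', dist.Categorical(probs),
               infer={'enumerate': 'sequential'})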

I realize this may be difficult for reasons mentioned in #2354, but is any parallelism like this possible in Pyro?

amifalk · May 22 '23 16:05

Hmm, I'd guess the most straightforward approach to inter-distribution CPU parallelism would be to rely on the PyTorch jit, simply by using JitTrace_ELBO or a similar jitted loss (see the sketch after the lists below).

Pros:

  • it's a one-line change
  • let PyTorch systems folks solve the problem

Cons:

  • the PyTorch jit seems to break every other release, and doesn't seem engineered to work with large compute graphs like those that arise in Pyro
  • jit-traced models require fixed static model structure
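A minimal sketch of that one-line change, assuming a model and guide already defined as above:

import pyro
from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam

# JitTrace_ELBO traces the model/guide pair once and compiles the ELBO
# computation with the PyTorch jit, which can then fuse and parallelize
# independent ops that a Python loop would otherwise run serially.
svi = SVI(model, guide, Adam({'lr': 0.01}), loss=JitTrace_ELBO())
for step in range(1000):
   svi.step()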

fritzo · May 28 '23 23:05

@amifalk Did you make any progress in this area? I'm facing the same issue when dealing with model selection from a set of models with significantly different structures. I have a partial solution: using poutine.mask to mask out the log-likelihood contributions in the model and guide traces that come from the models not currently selected by the discrete enumeration. Parallel enumeration can then be used.

However, for complicated model structures and a large set of models, the masking becomes quite involved and prone to mistakes that cannot be easily debugged.
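A minimal sketch of that masking pattern, assuming two candidate sub-models with illustrative names and priors:

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model(data):
   # discrete model index, enumerated in parallel by TraceEnum_ELBO
   m = pyro.sample('m', dist.Categorical(torch.ones(2)),
                   infer={'enumerate': 'parallel'})
   # zero out the log-prob contributions of whichever sub-model is unselected
   with poutine.mask(mask=(m == 0)):
      loc = pyro.sample('loc_a', dist.Normal(0., 10.))
      pyro.sample('obs_a', dist.Normal(loc, 1.), obs=data)
   with poutine.mask(mask=(m == 1)):
      scale = pyro.sample('scale_b', dist.LogNormal(0., 1.))
      pyro.sample('obs_b', dist.Normal(0., scale), obs=data)

The continuous sites would still need a guide (e.g. an autoguide), and the loss must be TraceEnum_ELBO so that m is enumerated out rather than sampled.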

pavleb · Dec 01 '23 17:12

Sorry, no updates currently @pavleb. We ended up resolving our speed issues by moving over to numpyro.

amifalk · Dec 01 '23 18:12