pyro
Implement Bayesian regression example from NumPyro in Pyro
This PR ports the Bayesian regression example/tutorial from NumPyro to Pyro. There is one issue I have not been able to resolve yet: I have to run the cell below before running the cells for Model 2 and Model 3. I did not face this issue with the NumPyro tutorial.
from pyro.infer import MCMC, NUTS

# Run NUTS
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=200)
Please let me know what you think @eb8680
@arijc76 sorry I am so slow this week, great work! Here are some minor comments:
Notebook
- Could you use the same model function for SVI and MCMC, instead of having a separate model_svi?
- Could you add a bit of descriptive text in the section on SVI? Perhaps a sentence or two before each of cells 20, 21 and 22 explaining what each cell is about to do and what the results mean.
- Typo: "For this we one of" -> "For this we use one of"
- Could you use the smoke_test parameter to set num_samples = 2 and num_warmup = 2, like you did with num_iter in cell 20?
Rendering
- Could you add your notebook's filename to tutorial/source/index.rst under the "Other inference algorithms" header?
- Could you run cd tutorial && make html locally and check that everything in the notebook (e.g. math, images) is rendered correctly on your machine? The generated HTML files should appear in tutorial/build/html/.
- Start a temporary local HTTP server in the tutorial directory (e.g. via python -m http.server), open index.html in a browser and check that your example appears under the "Other inference algorithms" header on the sidebar. Then click the link to make sure it works, and scroll through the generated HTML of your notebook looking for obvious visual errors.
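The local-server step can also be scripted. This is a minimal sketch using only the standard library; make_docs_server is a hypothetical helper, the directory to serve would be tutorial/build/html per the comment above, and the directory= argument of SimpleHTTPRequestHandler needs Python 3.7 or newer.

```python
import functools
import http.server


def make_docs_server(directory, port=8000):
    # Serve static files from `directory` (e.g. the Sphinx HTML output),
    # similar to running `python -m http.server` inside that directory.
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=directory
    )
    # Bind to the loopback interface only; port=0 picks a free port.
    return http.server.HTTPServer(("127.0.0.1", port), handler)
```

Running server = make_docs_server("tutorial/build/html") followed by server.serve_forever() and visiting http://127.0.0.1:8000/index.html should be roughly equivalent to python -m http.server --directory tutorial/build/html (also Python 3.7+).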
Testing
Your notebook is being executed correctly during CI (modulo reducing num_samples when smoke_test == True, as requested above), but it is currently failing with the following error:
=================================== FAILURES ===================================
_______ tutorial/source/bayesian_regression_mcmc_and_svi.ipynb::Cell 21 ________
Notebook cell execution failed
Cell 21: Cell execution caused an exception
Input:
mcmc.run(
age=torch.tensor(dset.AgeScaled.values, dtype=torch.float),
divorce=torch.tensor(dset.DivorceScaled.values, dtype=torch.float)
)
mcmc.summary()
samples_2 = mcmc.get_samples()
Traceback:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-22-596822a70e0a> in <module>
1 mcmc.run(
2 age=torch.tensor(dset.AgeScaled.values, dtype=torch.float),
----> 3 divorce=torch.tensor(dset.DivorceScaled.values, dtype=torch.float)
4 )
5 mcmc.summary()
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
561 # requires_grad", which happens with `jit_compile` under PyTorch 1.7
562 args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563 for x, chain_id in self.sampler.run(*args, **kwargs):
564 if num_samples[chain_id] == 0:
565 num_samples[chain_id] += 1
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
228 i if self.num_chains > 1 else None,
229 *args,
--> 230 **kwargs
231 ):
232 yield sample, i # sample, chain_id
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
142
143 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 144 kernel.setup(warmup_steps, *args, **kwargs)
145 params = kernel.initial_params
146 save_params = getattr(kernel, "save_params", sorted(params))
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
323 self._warmup_steps = warmup_steps
324 if self.model is not None:
--> 325 self._initialize_model_properties(args, kwargs)
326 if self.initial_params:
327 z = {k: v.detach() for k, v in self.initial_params.items()}
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
267 skip_jit_warnings=self._ignore_jit_warnings,
268 init_strategy=self._init_strategy,
--> 269 initial_params=self._initial_params,
270 )
271 self.potential_fn = potential_fn
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
462
463 if initial_params is None:
--> 464 prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
465 # Note that we deliberately do not exercise jit compilation here so as to
466 # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/util.py in <dictcomp>(.0)
462
463 if initial_params is None:
--> 464 prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
465 # Note that we deliberately do not exercise jit compilation here so as to
466 # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
KeyError: 'bA'
I have to run the below cell before running the cells for Model 2 and Model 3.
There may be some difference between the MCMC objects in Pyro and NumPyro. Could you try creating separate instances for each model rather than reusing the same one?
The custom predict_fn in your updated "Predictive Utility With Effect Handlers" section can be simplified: the pyro.plate("samples", 2000) you added plays the same role as JAX's vmap in this case, so you should be able to write something like this in cell 12, which more directly follows the original structure:
import torch
import pyro
from pyro import poutine

def predict(post_samples, model, *args, **kwargs):
    # Fix latent sites to posterior draws, then trace the model to read "obs"
    conditioned_model = poutine.condition(model, post_samples)
    model_trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
    return model_trace.nodes["obs"]["value"]

def predict_fn(post_samples):
    # model, dset and num_samples are defined earlier in the notebook
    with pyro.plate("samples", num_samples):
        return predict(post_samples, model, marriage=torch.tensor(dset.MarriageScaled.values, dtype=torch.float))
@eb8680 Thanks for this suggestion.
cd tutorial && make html
@eb8680 When I run cd tutorial && make html, the build stops at "waiting for workers...". After some time I get the error shown below. Can you let me know what I need to do to fix this? Thanks.
Running Sphinx v4.4.0
building [mo]: targets for 0 po files that are out of date
building [html]: targets for 78 source files that are out of date
updating environment: [new config] 78 added, 0 changed, 0 removed
reading sources... [100%] tensor_shapes .. working_memory
waiting for workers...
Warning, treated as error:
/Users/arijeetchatterjee/Documents/github_personal_projects/pyro/tutorial/source/ss-vae.ipynb:7:Unexpected indentation.
make: *** [html] Error 2
Can you let me know what I need to do for this?
I can't reproduce your error, but you can tell Sphinx not to treat warnings as errors by overriding the SPHINXOPTS environment variable used in our Makefile:
SPHINXOPTS="-E -j 8" make html
Thanks @eb8680
Sorry about the delay. I was having some trouble getting make html to work locally, but that's solved now (I made a couple of changes in conf.py, along with the SPHINXOPTS change in the Makefile suggested above).
I have completed the suggested changes, and now when I run make html the example appears under the "Other inference algorithms" header on the sidebar. The generated HTML of the notebook does not show any visual errors.
Can I commit only the updated version of the notebook for a review?
[UPDATED] I have committed the notebook with the changes as mentioned above. Please take a look. Thanks.