Add modified Bayesian regression tutorial with more direct PyTorch usage
This PR is an attempt at making a couple of small changes to the Pyro and `PyroModule` API to make `PyroModule` more compatible with vanilla PyTorch programming idioms. The API changes are simple, although the implementations inside `PyroModule` are a bit hacky and may not yet be correct.
Changes:

- Adds a new global configuration toggle `pyro.enable_module_local_param()` that allows `PyroModule` parameters to be stored locally rather than in the global parameter store. Currently implemented by associating a new `ParamStoreDict` object with each `PyroModule` instance, which may not be ideal (see the sketch after this list).
- Adds a backwards-compatible `__call__` method to `pyro.infer.ELBO` that returns a `torch.nn.Module` bound to a specific model and guide, allowing direct use of the PyTorch JIT API (e.g. `torch.jit.trace`).
- Forks the Bayesian regression tutorial into a PyTorch API usage tutorial to illustrate the PyTorch-native programming style facilitated by these changes and `PyroModule`.
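To make the first change concrete, here is a minimal sketch of the module-local parameter behavior the toggle is meant to enable. The `Scaler` module is a toy example, and the observable behavior of `pyro.enable_module_local_param()` is an assumption based on the description above, not code taken from this PR:

```python
import torch
import pyro
from pyro.nn import PyroModule, PyroParam

# Proposed toggle from this PR: store PyroParams on the module itself.
pyro.enable_module_local_param(True)

class Scaler(PyroModule):  # toy module for illustration only
    def __init__(self):
        super().__init__()
        self.scale = PyroParam(torch.ones(1))

    def forward(self, x):
        return self.scale * x

module = Scaler()
module(torch.randn(3))  # touching self.scale initializes the parameter

# The parameter lives on the module, so it can be handed to torch.optim ...
print(list(module.named_parameters()))
# ... and, with module-local params enabled, the global store should stay empty.
print(list(pyro.get_param_store().keys()))
```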
For context, here is a condensed training loop from the tutorial notebook that I was trying to enable:
```python
import torch
import pyro
from pyro.infer import Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule

# new: keep PyroParams out of the global parameter store
pyro.enable_module_local_param(True)

class BayesianRegression(PyroModule):
    ...  # model definition omitted; see the tutorial notebook

# Create fresh copies of model, guide, elbo
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
elbo = Trace_ELBO(num_particles=10)
# new: bind elbo to (model, guide) pair
elbo = elbo(model, guide)

# Populate guide parameters (x_data, y_data are defined earlier in the notebook)
elbo(x_data, y_data)

# new: use torch.optim directly
optim = torch.optim.Adam(guide.parameters(), lr=0.03)

# Temporarily disable runtime validation and compile ELBO
with pyro.validation_enabled(False):
    # new: use torch.jit.trace directly
    elbo = torch.jit.trace(elbo, (x_data, y_data), check_trace=False, strict=False)

# optimize
for j in range(1500):
    loss = elbo(x_data, y_data)
    optim.zero_grad()
    loss.backward()
    optim.step()

# prediction
predict_fn = Predictive(model, guide=guide, num_samples=800)
# new: use torch.jit.trace directly
predict_fn = torch.jit.trace(predict_fn, (x_data,), check_trace=False, strict=False)
samples = predict_fn(x_data)
```
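As a usage note, `Predictive` returns a dict mapping sample-site names to tensors with a leading sample dimension, so the 800 predictive draws can be summarized generically. The `summarize` helper below is hypothetical and not part of the PR:

```python
# Hypothetical helper: reduce the Predictive output to per-site summary stats,
# assuming `samples` maps site names to tensors whose first dim indexes draws.
def summarize(samples):
    return {
        site: {"mean": values.mean(0), "std": values.std(0)}
        for site, values in samples.items()
    }

site_stats = summarize(samples)
```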