
Add modified Bayesian regression tutorial with more direct PyTorch usage

Open · eb8680 opened this issue 3 years ago · 1 comment

This PR is an attempt at making a couple of small changes to the Pyro and PyroModule API to make PyroModule more compatible with vanilla PyTorch programming idioms. The API changes are simple, although the implementations inside PyroModule are a bit hacky and may not yet be correct.

Changes:

  • Adds a new global configuration toggle, pyro.enable_module_local_param(), that causes PyroModule parameters to be stored locally on each module rather than in the global parameter store. This is currently implemented by associating a new ParamStoreDict object with each PyroModule instance, which may not be ideal.
  • Adds a backwards-compatible __call__ method to pyro.infer.ELBO that returns a torch.nn.Module bound to a specific model and guide, allowing direct use of the PyTorch JIT API (e.g. torch.jit.trace).
  • Forks the Bayesian regression tutorial into a PyTorch API usage tutorial to illustrate the PyTorch-native programming style facilitated by these changes and PyroModule.

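To illustrate why having ELBO.__call__ return a torch.nn.Module is what unlocks the JIT API, here is a minimal PyTorch-only sketch. BoundLoss is a hypothetical stand-in for the bound module this PR proposes (it is not Pyro's actual implementation): once the loss is an ordinary nn.Module closed over the model, torch.jit.trace and torch.optim apply directly.

```python
import torch

class BoundLoss(torch.nn.Module):
    # Hypothetical stand-in for the nn.Module that ELBO.__call__
    # would return: it closes over a model and computes a scalar loss.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y):
        return ((self.model(x) - y) ** 2).mean()

model = torch.nn.Linear(3, 1)
loss_fn = BoundLoss(model)

x, y = torch.randn(10, 3), torch.randn(10, 1)

# Because loss_fn is a plain nn.Module, torch.jit.trace applies directly,
# and its parameters are visible to torch.optim via .parameters():
traced = torch.jit.trace(loss_fn, (x, y))
assert torch.allclose(traced(x, y), loss_fn(x, y))
```

The same pattern is what the condensed training loop below relies on: `elbo = elbo(model, guide)` produces a module that can be traced, differentiated, and optimized with stock PyTorch tooling.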
eb8680 avatar Dec 14 '21 15:12 eb8680

For context, here is a condensed training loop from the tutorial notebook that I was trying to enable:

# new: keep PyroParams out of the global parameter store
pyro.enable_module_local_param(True)

class BayesianRegression(PyroModule):
    ...

# Create fresh copies of model, guide, elbo
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
elbo = Trace_ELBO(num_particles=10)

# new: bind elbo to (model, guide) pair
elbo = elbo(model, guide)

# Run once to populate guide parameters
elbo(x_data, y_data)

# new: use torch.optim directly
optim = torch.optim.Adam(guide.parameters(), lr=0.03)

# Temporarily disable runtime validation and compile ELBO
with pyro.validation_enabled(False):
    # new: use torch.jit.trace directly
    elbo = torch.jit.trace(elbo, (x_data, y_data), check_trace=False, strict=False)

# optimize
for j in range(1500):
    loss = elbo(x_data, y_data)
    optim.zero_grad()
    loss.backward()
    optim.step()

# prediction
predict_fn = Predictive(model, guide=guide, num_samples=800)
# new: use torch.jit.trace directly
predict_fn = torch.jit.trace(predict_fn, (x_data,), check_trace=False, strict=False)
samples = predict_fn(x_data)

eb8680 avatar Dec 14 '21 16:12 eb8680