[Docs] Adding a doc page on debugging through a model (and notebook on Probabilistic PCA for tutorials)
Hi, I was preparing a tutorial on Probabilistic PCA (PPCA), mirroring the one in TensorFlow Probability (TFP).
After some non-trivial debugging and toying with to_event, I managed to get the example working in Pyro (though, to be honest, I'm still not very confident about it!).
I would like to suggest a tutorial that debugs some common issues in models (I'd say pyro.clear_param_store() and getting the shapes right would be two such candidates). While there is already an excellent guide on tensor shapes in Pyro, I think an example-driven tutorial that shows modelling errors and then works through fixing them would be very useful (and could reduce similar questions on the forum).
As an example, I'm copying the code I used for Probabilistic PCA. I created two versions (with and without plates). Again, it may be nice for such a tutorial to discuss the differences between the two models and when to use which.
import pyro
import pyro.distributions as dist
import torch
import matplotlib.pyplot as plt
# Ground truth: 200 noise-free points in 2-D generated from a 2-D latent space
torch.manual_seed(10)
W_gt = torch.rand(2, 2)
Z_gt = torch.randn(200, 2)
X = Z_gt @ W_gt  # shape (200, 2)
plt.scatter(X[:, 0], X[:, 1])
plt.axis("equal");
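For reference, the model without plates (defined next) corresponds to the following generative process. This is read off the code rather than taken from the TFP tutorial; D is data_dim, K is latent_dim, and n runs over the N = 200 rows of X, with x_n treated as a column vector:

```latex
\begin{aligned}
W_{dk} &\sim \mathcal{N}(0,\, 5^2), && d = 1,\dots,D,\; k = 1,\dots,K \\
z_{kn} &\sim \mathcal{N}(0,\, 1), && k = 1,\dots,K,\; n = 1,\dots,N \\
x_n \mid W, z_n &\sim \mathcal{N}(W z_n,\, I_D), && n = 1,\dots,N
\end{aligned}
```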

pyro.clear_param_store()
def ppca_model_without_plate(data, latent_dim):
    N, data_dim = data.shape
    # W: (data_dim, latent_dim), sampled as a single event
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([data_dim, latent_dim]),
            scale=5.0 * torch.ones([data_dim, latent_dim]),
        ).to_event(2),
    )
    # Z: (latent_dim, N); to_event(1) leaves a batch dim of size latent_dim
    Z = pyro.sample(
        "Z",
        dist.Normal(
            loc=torch.zeros([latent_dim, N]),
            scale=torch.ones([latent_dim, N]),
        ).to_event(1),
    )
    mean = (W @ Z).t()  # (N, data_dim)
    ob = dist.Normal(mean, 1.0).to_event(2)
    return pyro.sample("obs", ob, obs=data)
pyro.render_model(
    ppca_model_without_plate, model_args=(X, 1), render_distributions=True
)
ppca_model_without_plate(X, 1).shape
/Users/nipun/miniforge3/lib/python3.9/site-packages/pyro/primitives.py:137: RuntimeWarning: trying to observe a value outside of inference at obs
warnings.warn(
torch.Size([200, 2])
import pyro.poutine as poutine
trace = poutine.trace(ppca_model_without_plate).get_trace(X, 1)
trace.compute_log_prob() # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
Param Sites:
Sample Sites:
W dist | 2 1
value | 2 1
log_prob |
Z dist 1 | 200
value 1 | 200
log_prob 1 |
obs dist | 200 2
value | 200 2
log_prob |
pyro.clear_param_store()
auto_guide = pyro.infer.autoguide.AutoNormal(ppca_model_without_plate)
trace = poutine.trace(auto_guide).get_trace(X, 1)
trace.compute_log_prob() # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
Param Sites:
AutoNormal.locs.W 2 1
AutoNormal.scales.W 2 1
AutoNormal.locs.Z 1 200
AutoNormal.scales.Z 1 200
Sample Sites:
W_unconstrained dist | 2 1
value | 2 1
log_prob |
W dist | 2 1
value | 2 1
log_prob |
Z_unconstrained dist 1 | 200
value 1 | 200
log_prob 1 |
Z dist 1 | 200
value 1 | 200
log_prob 1 |
import logging
# The root logger's default level is WARNING, so logging.info() prints nothing without this.
logging.basicConfig(format="%(message)s", level=logging.INFO)
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(ppca_model_without_plate, auto_guide, adam, elbo)
losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(X, 1)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))
plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

pyro.clear_param_store()
def ppca_model(data, latent_dim):
    N, data_dim = data.shape
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([data_dim, latent_dim]),
            scale=5.0 * torch.ones([data_dim, latent_dim]),
        ).to_event(2),
    )
    with pyro.plate("data", N):
        # One latent vector per data point: batch shape (N,) from the plate and
        # event shape (latent_dim,), so z_n has shape (N, latent_dim) for any latent_dim.
        z_n = pyro.sample(
            "z",
            dist.Normal(torch.zeros(latent_dim), torch.ones(latent_dim)).to_event(1),
        )
        mean = z_n @ W.t()  # (N, latent_dim) @ (latent_dim, data_dim) -> (N, data_dim)
        # Each row of data is one event of size data_dim; the plate covers the N rows.
        pyro.sample("obs", dist.Normal(mean, 1.0).to_event(1), obs=data)
pyro.render_model(
    ppca_model, model_args=(X, 1), render_distributions=True
)
auto_guide2 = pyro.infer.autoguide.AutoNormal(ppca_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(ppca_model, auto_guide2, adam, elbo)
losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(X, 1)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))
plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

Related #3030
I love the idea of a debugging-focused tutorial. And I suspect core Pyro devs are the worst people to write such a tutorial since it's been so long since we first stubbed our toes 🤣
@nipunbatra this would be a great tutorial to contribute!
Hi @fritzo, thanks.
I have added a long-ish notebook here.
As you might notice, it will need some input from you, especially on:
- shape-related confusions
- best practices for specifying shapes
- plate vs. non-plate models and when to use which
- any other debugging tips