
[Docs] Adding a doc page on debugging through a model (and notebook on Probabilistic PCA for tutorials)

Open nipunbatra opened this issue 3 years ago • 2 comments

Hi, I was preparing a tutorial on probabilistic PCA (PPCA) that mirrors the one in TensorFlow Probability (TFP).

After some non-trivial debugging and toying with to_event, I managed to get the example working in Pyro (though, to be honest, I'm still not very confident!).

I would like to suggest a tutorial that debugs some common issues in models (pyro.clear_param_store() and getting the shapes right would be two such candidates). While there is already an excellent guide on tensor shapes in Pyro, I think an example-driven tutorial page that first shows the modelling errors and then works through fixing them would be very useful (and might reduce similar questions on the forums).

As an example, I'm copying the code I used for PPCA. I created two versions (with and without plates). It would also be nice for such a tutorial to discuss the differences between the two models and when to use which.

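import pyro
import pyro.distributions as dist
import torch
import matplotlib.pyplot as plt

torch.manual_seed(10)
W_gt = torch.rand(2, 2)     # ground-truth loadings
Z_gt = torch.randn(200, 2)  # ground-truth latent coordinates

X = Z_gt @ W_gt  # noise-free synthetic data, shape [200, 2]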
plt.scatter(X[:, 0], X[:, 1])
plt.axis('equal');

[Figure: scatter plot of the synthetic data X]
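For reference, this is the generative model that both versions encode, as I read the code, with D = data_dim, K = latent_dim and N data points (the prior scale 5 and the unit observation noise are taken from the code; in the no-plate version the z_n are stored as the columns of Z):

```latex
\begin{aligned}
W_{dk} &\sim \mathcal{N}(0,\, 5^2), && d = 1,\dots,D,\; k = 1,\dots,K \\
\mathbf{z}_n &\sim \mathcal{N}(\mathbf{0},\, I_K), && n = 1,\dots,N \\
\mathbf{x}_n \mid W, \mathbf{z}_n &\sim \mathcal{N}(W \mathbf{z}_n,\, I_D)
\end{aligned}
```

Note that the synthetic X is generated as Z_gt @ W_gt without added noise, so it does not follow the unit-noise likelihood exactly.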

pyro.clear_param_store()


def ppca_model_without_plate(data, latent_dim):
    N, data_dim = data.shape
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([data_dim, latent_dim]),
            scale=5.0 * torch.ones([data_dim, latent_dim]),
        ).to_event(2)
    )
    Z = pyro.sample(
        "Z",
        dist.Normal(
            loc=torch.zeros([latent_dim, N]),
            scale=torch.ones([latent_dim, N]),
        ).to_event(1),
    )

    mean = (W @ Z).t()  # [data_dim, N] -> [N, data_dim]

    ob = dist.Normal(mean, 1.0).to_event(2)
    return pyro.sample("obs", ob, obs=data)


pyro.render_model(
    ppca_model_without_plate, model_args=(X, 1), render_distributions=True
)

[Figure: graphical model of ppca_model_without_plate rendered by pyro.render_model]

ppca_model_without_plate(X, 1).shape
/Users/nipun/miniforge3/lib/python3.9/site-packages/pyro/primitives.py:137: RuntimeWarning: trying to observe a value outside of inference at obs
  warnings.warn(

torch.Size([200, 2])

import pyro.poutine as poutine

trace = poutine.trace(ppca_model_without_plate).get_trace(X, 1)
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:          
 Param Sites:          
Sample Sites:          
       W dist   |   2 1
        value   |   2 1
     log_prob   |      
       Z dist 1 | 200  
        value 1 | 200  
     log_prob 1 |      
     obs dist   | 200 2
        value   | 200 2
     log_prob   |      

pyro.clear_param_store()
auto_guide = pyro.infer.autoguide.AutoNormal(ppca_model_without_plate)
trace = poutine.trace(auto_guide).get_trace(X, 1)
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
       Trace Shapes:            
        Param Sites:            
   AutoNormal.locs.W 2   1      
 AutoNormal.scales.W 2   1      
   AutoNormal.locs.Z 1 200      
 AutoNormal.scales.Z 1 200      
       Sample Sites:            
W_unconstrained dist     |   2 1
               value     |   2 1
            log_prob     |      
              W dist     |   2 1
               value     |   2 1
            log_prob     |      
Z_unconstrained dist 1   | 200  
               value 1   | 200  
            log_prob 1   |      
              Z dist 1   | 200  
               value 1   | 200  
            log_prob 1   |      
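
import logging

# Without this, logging.info() messages are not shown (the default level is WARNING).
logging.basicConfig(format="%(message)s", level=logging.INFO)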
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(ppca_model_without_plate, auto_guide, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(X, 1)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))

plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

[Figure: ELBO loss vs. SVI step for ppca_model_without_plate]

pyro.clear_param_store()


def ppca_model(data, latent_dim):
    N, data_dim = data.shape
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([data_dim, latent_dim]),
            scale=5.0 * torch.ones([data_dim, latent_dim]),
        ).to_event(2),
    )
   
    with pyro.plate("data", N):
        # loc has shape [1, latent_dim]; inside the plate, dim -1 is expanded
        # to size N, which only works when latent_dim == 1
        z_n = pyro.sample(
            "z",
            dist.Normal(loc=torch.zeros([1, latent_dim]), scale=torch.ones([1, latent_dim])),
        )

        mean = (W @ z_n).t()
        pyro.sample("obs", dist.Normal(mean, 1.0).to_event(1), obs=data)


pyro.render_model(
    ppca_model, model_args=(X, 1), render_distributions=True
)

[Figure: graphical model of ppca_model rendered by pyro.render_model]

auto_guide2 = pyro.infer.autoguide.AutoNormal(ppca_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(ppca_model, auto_guide2, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(X, 1)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))

plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

[Figure: ELBO loss vs. SVI step for ppca_model]

Related #3030

nipunbatra • Mar 01 '22 10:03

I love the idea of a debugging-focused tutorial. And I suspect core Pyro devs are the worst people to write such a tutorial since it's been so long since we first stubbed our toes 🤣

@nipunbatra this would be a great tutorial to contribute!

fritzo • Mar 01 '22 20:03

Hi @fritzo, thanks.

I have added a longish notebook here.

As you may notice, it will need some input from you, especially on:

  • some shape-related confusions
  • best practices for specifying shapes
  • plate vs. non-plate models and when to use which
  • any other debugging tips

nipunbatra • Mar 08 '22 09:03