[Bug] Dimensions keep growing with Predictive in Africa Tutorial

Open • maulberto3 opened this issue 2 years ago • 1 comment

Issue Description

In Jupyter, running the following cell several times:

pred = infer.Predictive(model, guide=auto_guide, num_samples=2)
svi_samples = pred(is_cont_africa, ruggedness, log_gdp)
log_gdp = svi_samples['obs']
print(log_gdp.shape)

makes the printed shape gain an extra leading dimension on every run:

torch.Size([2, 170])
torch.Size([2, 2, 170])
torch.Size([2, 2, 2, 170])
torch.Size([2, 2, 2, 2, 170])...

Environment

  • OS and Python version: Kaggle (Linux), Python 3.7.12
  • PyTorch version: 1.11.0
  • Pyro version: 1.8.1

Code Snippet

Already provided above.

Hope this helps.

maulberto3 • May 02 '22 05:05

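Found the cause. The cell assigns the sampled observations back to log_gdp, so every re-run passes the previous run's samples, not the original data, into Predictive. Each run's output therefore carries one more leading num_samples dimension than the last. Solution: store the samples under a different name so log_gdp keeps the original data.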

Working example:

pred = infer.Predictive(model, guide=auto_guide, num_samples=20)
svi_samples = pred(is_cont_africa, ruggedness, log_gdp)
log_gdp_ = svi_samples['obs']
log_gdp_.shape
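The growth pattern in the report can be reproduced without Pyro. The following is a minimal sketch that tracks only tuple shapes. It assumes, based on the shapes printed in this issue rather than on Pyro's source, that each Predictive call returns svi_samples['obs'] with a num_samples dimension prepended to whatever was passed as log_gdp. The helper names obs_shape_after_predictive and rerun_cell are hypothetical. The sketch uses num_samples=2 to match the original output.

```python
# Toy model of the shapes reported in this issue (NOT Pyro's implementation).
# Assumption: each Predictive call prepends a num_samples dimension to the
# shape of the data passed in as log_gdp.
N_COUNTRIES = 170  # data length, taken from the printed shapes


def obs_shape_after_predictive(data_shape, num_samples):
    """Hypothetical stand-in returning the shape of svi_samples['obs']."""
    return (num_samples,) + tuple(data_shape)


def rerun_cell(times, num_samples, overwrite):
    """Simulate re-running the notebook cell `times` times.

    overwrite=True mimics `log_gdp = svi_samples['obs']` (the buggy cell);
    overwrite=False mimics `log_gdp_ = svi_samples['obs']` (the renamed fix).
    """
    log_gdp_shape = (N_COUNTRIES,)  # shape of the original observed data
    shapes = []
    for _ in range(times):
        obs_shape = obs_shape_after_predictive(log_gdp_shape, num_samples)
        if overwrite:
            # The samples replace the data and are fed back in on the next run
            log_gdp_shape = obs_shape
        shapes.append(obs_shape)
    return shapes


print(rerun_cell(4, 2, overwrite=True))
# [(2, 170), (2, 2, 170), (2, 2, 2, 170), (2, 2, 2, 2, 170)]
print(rerun_cell(4, 2, overwrite=False))
# [(2, 170), (2, 170), (2, 170), (2, 170)]
```

With the renamed variable, log_gdp stays at its original shape, so every run produces the same output shape no matter how often the cell is executed.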

Hope this still helps.

maulberto3 • May 02 '22 05:05