[Bug] Dimension augmenting with Predictive in Africa Tutorial
Issue Description
In Jupyter, running the following cell several times:
pred = infer.Predictive(model, guide=auto_guide, num_samples=2)
svi_samples = pred(is_cont_africa, ruggedness, log_gdp)
log_gdp = svi_samples['obs']
print(log_gdp.shape)
The printed shape gains a leading dimension on every run, i.e.
torch.Size([2, 170])
torch.Size([2, 2, 170])
torch.Size([2, 2, 2, 170])
torch.Size([2, 2, 2, 2, 170])...
Environment
- OS and python version: Kaggle (linux), Python 3.7.12
- PyTorch version: 1.11.0
- Pyro version: 1.8.1
Code Snippet
Already provided above.
Hope this helps someone.
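Found the error. Each run of the cell reassigned log_gdp to the prediction output, so the next run passed that output (with its extra leading sample dimension) back into pred as the observed data. Solution: store the result under a different name.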
Working example:
pred = infer.Predictive(model, guide=auto_guide, num_samples=20)
svi_samples = pred(is_cont_africa, ruggedness, log_gdp)
log_gdp_ = svi_samples['obs']
log_gdp_.shape
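For anyone who wants to see the shape growth without installing Pyro, here is a minimal sketch in plain Python that only tracks shapes. It is not Pyro's Predictive: `predicted_obs_shape` and `rerun_cell` are hypothetical helpers, and the assumption (based on the shapes printed above) is that the returned 'obs' has the shape of the observed data that was passed in, with a sample dimension of size num_samples prepended.

```python
# Hypothetical stand-in for the shape behaviour reported in this issue.
# This is NOT Pyro's Predictive; it only models "prepend a sample dimension
# to the shape of the observed data that was passed in".
def predicted_obs_shape(obs_shape, num_samples):
    return (num_samples,) + tuple(obs_shape)


def rerun_cell(obs_shape, num_samples, runs, reuse_name):
    """Return the shape that each run of the notebook cell would print."""
    current = tuple(obs_shape)
    printed = []
    for _ in range(runs):
        out = predicted_obs_shape(current, num_samples)
        printed.append(out)
        if reuse_name:
            # Mirrors `log_gdp = svi_samples['obs']`: the input is overwritten,
            # so the next run feeds the previous output back in.
            current = out
        # Otherwise the result goes to a new name (log_gdp_) and the
        # original input is left untouched.
    return printed


buggy = rerun_cell((170,), num_samples=2, runs=3, reuse_name=True)
fixed = rerun_cell((170,), num_samples=2, runs=3, reuse_name=False)
print(buggy)
print(fixed)
```

With reuse_name=True the shapes grow by one dimension per run, matching the output in the issue description; with reuse_name=False every run prints the same shape.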
Still hope to help.