pyro
Overly strict validation logic during guide enumeration
The error-checking logic in check_site_shape, as used by TraceEnum_ELBO, incorrectly raises an error when performing guide-side enumeration on a model/guide pair whose factorizations differ such that the guide introduces an extra dependency. In the example below, the model samples x first and then i conditioned on x, whereas the guide samples i first and then x conditioned on i. This came up in @qinqian's project.
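For the example that follows, the two factorizations and the objective that guide-side enumeration targets can be written out as below. The symbols are our own labels, not Pyro names: π(x) is the model's Bernoulli probability, θ is the guide parameter p, and μ holds the guide's locations for x. Enumeration makes the sum over i exact, so only x is sampled:

```latex
\begin{aligned}
p(x, i) &= \mathcal{N}(x;\, 0, 1)\,\mathrm{Bernoulli}\big(i;\, \pi(x)\big),
  \qquad \pi(x) = \begin{cases} 0.8 & x > 0 \\ 0.3 & x \le 0 \end{cases} \\
q(i, x) &= \mathrm{Bernoulli}(i;\, \theta)\,\mathcal{N}(x;\, \mu_i, 1),
  \qquad \mu = (-0.5,\ 1) \\
\mathcal{L} &= \sum_{i \in \{0, 1\}} q(i)\,
  \mathbb{E}_{q(x \mid i)}\!\left[\log p(x, i) - \log q(i) - \log q(x \mid i)\right]
\end{aligned}
```

Nothing in this expression requires the guide to follow the model's sampling order, which is why the pair should be accepted.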
Consider the following trivial model/guide pair:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate

def model():
    # Model: x first, then i with a probability that depends on the sign of x.
    x = pyro.sample("x", dist.Normal(0, 1))
    p = torch.where(x > 0, torch.tensor(0.8), torch.tensor(0.3))
    i = pyro.sample("i", dist.Bernoulli(p))

@config_enumerate
def guide():
    # Guide: i first (enumerated), then x with a location selected by i.
    p = pyro.param("p", lambda: torch.tensor(0.3))
    i = pyro.sample("i", dist.Bernoulli(p)).long()
    x = pyro.sample("x", dist.Normal(torch.tensor([-0.5, 1.])[..., i], 1))

TraceEnum_ELBO().differentiable_loss(model, guide)
This model/guide pair is valid and should be allowed, but running the snippet currently raises the following error:
Traceback (most recent call last):
  File "error_example.py", line 21, in <module>
    TraceEnum_ELBO().differentiable_loss(model, guide)
  File "/Users/ebingham/development/pyro/pyro/infer/traceenum_elbo.py", line 425, in differentiable_loss
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/Users/ebingham/development/pyro/pyro/infer/traceenum_elbo.py", line 392, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/Users/ebingham/development/pyro/pyro/infer/traceenum_elbo.py", line 340, in _get_trace
    "flat", self.max_plate_nesting, model, guide, args, kwargs
  File "/Users/ebingham/development/pyro/pyro/infer/enum.py", line 69, in get_importance_trace
    check_site_shape(site, max_plate_nesting)
  File "/Users/ebingham/development/pyro/pyro/util.py", line 439, in check_site_shape
    "Try increasing pyro.markov history size",
ValueError: Enumeration dim conflict at site "i"
Try increasing pyro.markov history size
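As a sanity check that this pair defines a well-posed objective, the ELBO that guide-side enumeration targets here can be computed without Pyro. The sketch below is a plain-Python illustration under our own assumptions. It fixes the guide parameter p at its initial value 0.3 and uses the guide's locations -0.5 and 1.0 for x. All helper names are hypothetical and not part of Pyro. Under the guide x ~ N(m, 1), so P(x > 0) = Φ(m), and E[log N(x; 0, 1)] plus the entropy of N(m, 1) simplifies to -m²/2. This gives a closed form. A Monte Carlo estimate that sums over i exactly and samples only x, roughly what guide-side enumeration does for this guide, should agree with it.

```python
import math
import random

# Hypothetical plain-Python stand-ins for the model/guide in the issue.
LOCS = (-0.5, 1.0)  # guide locations for x given i = 0 and i = 1
THETA = 0.3         # guide's Bernoulli parameter p at its initial value


def normal_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def normal_log_prob(x, loc):
    """Log density of N(loc, 1) at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - loc) ** 2


def bernoulli_log_prob(i, prob):
    return math.log(prob if i == 1 else 1.0 - prob)


def model_prob(x):
    # Mirrors torch.where(x > 0, 0.8, 0.3) in the model.
    return 0.8 if x > 0 else 0.3


def exact_elbo(theta=THETA, locs=LOCS):
    """Exact sum over i; the expectation over x is available in closed form."""
    total = 0.0
    for i in (0, 1):
        q_i = theta if i == 1 else 1.0 - theta
        m = locs[i]
        p_pos = normal_cdf(m)  # P(x > 0) when x ~ N(m, 1)
        e_log_p_i = (p_pos * bernoulli_log_prob(i, 0.8)
                     + (1.0 - p_pos) * bernoulli_log_prob(i, 0.3))
        # E[log N(x; 0, 1)] + entropy of N(m, 1) simplifies to -m^2 / 2.
        total += q_i * (-0.5 * m * m + e_log_p_i - math.log(q_i))
    return total


def enumerated_mc_elbo(num_samples, theta=THETA, locs=LOCS, seed=0):
    """Enumerate i exactly, sample x ~ q(x | i), and average log p - log q."""
    rng = random.Random(seed)
    total = 0.0
    for i in (0, 1):
        q_i = theta if i == 1 else 1.0 - theta
        acc = 0.0
        for _ in range(num_samples):
            x = rng.gauss(locs[i], 1.0)
            log_p = normal_log_prob(x, 0.0) + bernoulli_log_prob(i, model_prob(x))
            log_q = math.log(q_i) + normal_log_prob(x, locs[i])
            acc += log_p - log_q
        total += q_i * acc / num_samples
    return total


if __name__ == "__main__":
    print("closed-form ELBO:", exact_elbo())
    print("Monte Carlo ELBO:", enumerated_mc_elbo(100_000))
```

With no observed sites the evidence is 1, so the ELBO equals -KL(q‖p) and cannot be positive. By our calculation it is about -0.26 here. Once the shape check accepts this pair, we would expect the value returned by differentiable_loss to be a noisy estimate of the negation of this quantity.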