
Overly strict validation logic during guide enumeration

Open • eb8680 opened this issue • 0 comments

The error-checking logic in `check_site_shape` (defined in `pyro/util.py` and invoked by `TraceEnum_ELBO`) incorrectly raises an error during guide-side enumeration when the model and guide factorize differently, specifically when the guide introduces a dependency that the model does not have. This came up in @qinqian's project.

Consider the following trivial model/guide pair:

```python
import torch
import pyro

import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate


def model():
    x = pyro.sample("x", dist.Normal(0, 1))
    p = torch.where(x > 0, torch.tensor(0.8), torch.tensor(0.3))  # model: i depends on x
    i = pyro.sample("i", dist.Bernoulli(p))


@config_enumerate  # enumerate discrete guide sites ("i") in parallel
def guide():
    p = pyro.param("p", lambda: torch.tensor(0.3))
    i = pyro.sample("i", dist.Bernoulli(p)).long()
    x = pyro.sample("x", dist.Normal(torch.tensor([-0.5, 1.])[..., i], 1))  # guide: x depends on i


TraceEnum_ELBO().differentiable_loss(model, guide)
```
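Written as densities, the model above factorizes as $p(x)\,p(i \mid x)$ while the guide factorizes as $q(i)\,q(x \mid i)$. With guide-side enumeration, the usual ELBO for this pair sums exactly over $i$ and samples only $x$ (this is the standard estimator written out by hand, not an excerpt from Pyro's source):

$$
\mathcal{L} = \sum_{i \in \{0, 1\}} q(i)\, \mathbb{E}_{q(x \mid i)}\Big[\log p(x) + \log p(i \mid x) - \log q(i) - \log q(x \mid i)\Big]
$$

Because $x$ is drawn from $q(x \mid i)$ once per value of $i$, the model term $\log p(i \mid x)$ is evaluated with a batch dimension indexed by the enumerated $i$. Every term is well defined, which is why this pair should be accepted.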

This snippet is valid and should be allowed, but running it currently raises the following error:

```
Traceback (most recent call last):
  File "error_example.py", line 21, in <module>
    TraceEnum_ELBO().differentiable_loss(model, guide)
  File "/Users/ebingham/development/pyro/pyro/infer/traceenum_elbo.py", line 425, in differentiable_loss
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/Users/ebingham/development/pyro/pyro/infer/traceenum_elbo.py", line 392, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/Users/ebingham/development/pyro/pyro/infer/traceenum_elbo.py", line 340, in _get_trace
    "flat", self.max_plate_nesting, model, guide, args, kwargs
  File "/Users/ebingham/development/pyro/pyro/infer/enum.py", line 69, in get_importance_trace
    check_site_shape(site, max_plate_nesting)
  File "/Users/ebingham/development/pyro/pyro/util.py", line 439, in check_site_shape
    "Try increasing pyro.markov history size",
ValueError: Enumeration dim conflict at site "i"
  Try increasing pyro.markov history size
```
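To see why this particular message fires, here is a plain-Python sketch of the kind of shape test involved. It is a simplified, hypothetical reconstruction, not Pyro's code: `enum_dim_conflict` is an illustrative name, and the sketch assumes that (a) with no plates the guide enumerates `i` at dim -1, (b) the model site `i`, replayed against the guide, inherits that enumeration dim, and (c) the check rejects a site whose distribution already has a non-singleton batch dimension at its own enumeration dim.

```python
# Hypothetical, simplified model of the enumeration-dim check. The name
# enum_dim_conflict is illustrative, not Pyro's API; shapes are plain tuples.


def enum_dim_conflict(batch_shape, enum_dim):
    """Return True if the distribution already has a non-singleton batch
    dimension at the (negative) dim reserved for enumerating this site."""
    if enum_dim is None:
        return False  # site is not enumerated, so there is nothing to check
    return len(batch_shape) >= -enum_dim and batch_shape[enum_dim] != 1


# Guide side (assumption): with no plates, "i" is enumerated at dim -1, so
# its value has shape (2,); x's loc is indexed by i, so x has shape (2,).
ENUM_DIM_I = -1
guide_x_shape = (2,)

# Model side, replayed against the guide: torch.where(x > 0, ...) broadcasts
# to x's shape, so Bernoulli(p) at site "i" gets batch_shape (2,).
model_i_batch_shape = guide_x_shape

print(enum_dim_conflict(model_i_batch_shape, ENUM_DIM_I))  # flagged
# If the model's p(i) did not depend on x, its batch shape would be ().
print(enum_dim_conflict((), ENUM_DIM_I))  # not flagged
print(enum_dim_conflict((1,), ENUM_DIM_I))  # not flagged
```

Under these assumptions the check fires because the model's `p` for site `i` is computed from `x`, which the guide sampled once per enumerated value of `i`. The batch dimension being rejected is exactly the enumeration of `i` itself, so for this model/guide pair the conflict is spurious.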

eb8680 • Aug 26 '21 16:08