[bug] Incorrect discrete inference with `sequential` enumeration

Open • gcskoenig opened this issue 2 years ago • 1 comment

The discrete inference results I get with parallel enumeration are correct, but the results with sequential enumeration are not. In theory, both enumeration strategies should return the same posterior.
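
To make the "both should return the same result" claim concrete, here is a minimal sketch that does not use Pyro, written for the model in the example below: a binary site `y` with prior P(y=1) = 0.5 and an observed child `x_ch = 1.0` drawn from Normal(y, 1). Sequential enumeration scores each value of `y` in a separate pass; parallel enumeration scores all values at once along an extra array dimension. The helper names (`log_joint`, `posterior_sequential`, `posterior_parallel`) are hypothetical, and this is not how Pyro implements enumeration internally. It only illustrates that both strategies normalize the same joint density and therefore must agree.

```python
import math

import numpy as np


def log_joint(y, p=0.5, x_ch=1.0):
    """log p(y) + log p(x_ch | y) for y in {0, 1}; accepts scalars or arrays."""
    # Bernoulli prior on y (same as Binomial with total_count=1)
    log_prior = np.where(y == 1, np.log(p), np.log1p(-p))
    # Normal(y, 1) log-density of the observed child x_ch
    log_lik = -0.5 * (x_ch - y) ** 2 - 0.5 * np.log(2 * np.pi)
    return log_prior + log_lik


def posterior_sequential(p=0.5, x_ch=1.0):
    """P(y=1 | x_ch), scoring one value of y per pass in a Python loop."""
    logs = [float(log_joint(np.array(v), p, x_ch)) for v in (0, 1)]
    m = max(logs)  # subtract the max for numerical stability
    weights = [math.exp(lp - m) for lp in logs]
    return weights[1] / sum(weights)


def posterior_parallel(p=0.5, x_ch=1.0):
    """P(y=1 | x_ch), scoring all values of y at once via broadcasting."""
    ys = np.array([0, 1])  # the enumeration dimension
    logs = log_joint(ys, p, x_ch)
    weights = np.exp(logs - logs.max())
    return float(weights[1] / weights.sum())


print(posterior_sequential(), posterior_parallel())  # both about 0.6225
```

Either way the posterior is a ratio of the same two joint densities, so a correct implementation of either strategy should converge to the same value when sampled.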

I created a minimal working example to demonstrate the problem: with parallel enumeration the estimated posterior mean of `y` is about 0.62, whereas with sequential enumeration it is about 0.5.

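```python
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate, infer_discrete


@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    # Binary latent y with P(y = 1) = x_pa_obs.
    # The site-level "sequential" setting takes precedence over the
    # decorator's parallel default; drop it (or set "parallel") to compare.
    p = x_pa_obs
    y = pyro.sample('y_pre', dist.Binomial(probs=p, total_count=1),
                    infer={"enumerate": "sequential"},
                    obs=y_obs)

    # Observed child: x_ch ~ Normal(y, 1)
    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)

    return y


data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
# temperature=1 samples from the posterior (temperature=0 would return the MAP)
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

# Draw 10^4 posterior samples of y given x_ch = 1.0
y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))

smpl = torch.stack(y_posts)
# Parallel enumeration gives about 0.62 here; sequential gives about 0.5
print(f"mean: {smpl.mean()}")
```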
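
As a sanity check on the numbers above, the exact posterior can be worked out by hand. With prior P(y=1) = 0.5 and the observation x_ch = 1.0 under Normal(y, 1), the log-likelihood ratio of y=1 versus y=0 is 0 - (-0.5) = 0.5, so P(y=1 | x_ch) = 1 / (1 + exp(-0.5)) ≈ 0.6225. The sketch below (plain Python; `exact_posterior` and `standard_error` are hypothetical helper names) also computes the standard error of the mean of 10^4 independent posterior draws. The gap between 0.5 and 0.6225 is many standard errors wide, so the sequential result cannot be explained as sampling noise; 0.5 is exactly the prior mean, which is what one would get if the observation were ignored.

```python
import math


def exact_posterior(p=0.5, x_ch=1.0, sigma=1.0):
    """P(y=1 | x_ch) for y ~ Bernoulli(p) and x_ch ~ Normal(y, sigma)."""
    # Log-likelihoods of x_ch under y=0 and y=1 (the shared normalizer cancels)
    ll0 = -0.5 * ((x_ch - 0.0) / sigma) ** 2
    ll1 = -0.5 * ((x_ch - 1.0) / sigma) ** 2
    # Posterior log-odds = prior log-odds + log-likelihood ratio
    log_odds = math.log(p) - math.log1p(-p) + ll1 - ll0
    return 1.0 / (1.0 + math.exp(-log_odds))


def standard_error(prob, n):
    """Standard error of the mean of n independent Bernoulli(prob) draws."""
    return math.sqrt(prob * (1.0 - prob) / n)


post = exact_posterior()
se = standard_error(post, 10**4)
print(f"exact posterior: {post:.4f}")  # 0.6225
print(f"std. error (n=10^4): {se:.4f}")
print(f"0.5 is {(post - 0.5) / se:.1f} standard errors from the exact value")
```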

gcskoenig, May 03 '22 17:05

This pull request (https://github.com/pyro-ppl/pyro/pull/3238) should fix the bug. @eb8680

qinqian, Jul 13 '23 01:07