[bug] Incorrect discrete inference with `sequential` enumeration
The discrete inference results I get with `parallel` enumeration are accurate, but the results with `sequential` enumeration are not. In theory, both strategies should return the same result (source).
I created a minimal working example to demonstrate the problem. With `parallel` enumeration the sampled mean of `y` is about 0.62; with `sequential` enumeration it is about 0.5.
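To check which number is right, you can compute the exact posterior by hand. The sketch below is plain Python, not Pyro. It treats `Binomial(probs=p, total_count=1)` as a Bernoulli(p) prior on `y`, uses the `Normal(y, 1.0)` likelihood for `x_ch`, and enumerates both values of `y`. The helper names `normal_pdf` and `exact_posterior_y1` are illustrative only.

```python
import math


def normal_pdf(x, mu, sigma=1.0):
    """Density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def exact_posterior_y1(p, x_ch, sigma=1.0):
    """P(y = 1 | x_ch) for y ~ Bernoulli(p) and x_ch ~ Normal(y, sigma).

    Enumerates both values of y, computes the joint density of (y, x_ch)
    for each, and normalises.
    """
    joint = {
        y: (p if y == 1 else 1.0 - p) * normal_pdf(x_ch, float(y), sigma)
        for y in (0, 1)
    }
    return joint[1] / (joint[0] + joint[1])


posterior = exact_posterior_y1(p=0.5, x_ch=1.0)
print(f"exact P(y=1 | x_ch=1.0) = {posterior:.4f}")
```

With `p = 0.5` and `x_ch = 1.0` this gives about 0.6225, which is analytically 1 / (1 + exp(-0.5)). That matches the parallel result. An estimate of ca. 0.5 coincides with the prior probability `p`, which is what you would get if the observation of `x_ch` were ignored.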
```python
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate
from pyro.infer import infer_discrete


@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    p = x_pa_obs
    # Binary latent y. Changing "sequential" to "parallel" gives the correct result (~0.62).
    y = pyro.sample('y_pre', dist.Binomial(probs=p, total_count=1),
                    infer={"enumerate": "sequential"},
                    obs=y_obs)
    # Observed child whose mean depends on y.
    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)
    return y


data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
# temperature=1 draws y from its posterior (temperature=0 would return the MAP value).
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))
smpl = torch.stack(y_posts)
print(f"mean: {smpl.mean()}")
```
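A quick check shows whether the sequential estimate could just be Monte Carlo noise: compare the gap with the standard error of a mean over 10^4 draws. This sketch assumes the draws are independent Bernoulli samples from the exact posterior of this model. `z_score` is a hypothetical helper.

```python
import math

# Exact posterior P(y=1 | x_ch=1.0) for this model. With p = 0.5 and unit-variance
# Normal likelihoods, the density ratio N(1; 1, 1) / N(1; 0, 1) equals exp(0.5),
# so the posterior reduces to sigmoid(0.5).
p_exact = 1.0 / (1.0 + math.exp(-0.5))

n_draws = 10**4  # number of calls to model_discrete in the example

# Standard error of the mean of n_draws independent Bernoulli(p_exact) samples.
std_err = math.sqrt(p_exact * (1.0 - p_exact) / n_draws)


def z_score(observed_mean):
    """Number of standard errors between observed_mean and the exact posterior."""
    return (observed_mean - p_exact) / std_err


print(f"exact mean {p_exact:.4f}, standard error {std_err:.4f}")
print(f"z-score of a 0.5 estimate: {z_score(0.5):.1f}")
```

The standard error is about 0.005. An estimate of 0.5 therefore lies roughly 25 standard errors below the exact value, far outside what sampling variability would explain.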
This pull request (https://github.com/pyro-ppl/pyro/pull/3238) should fix the bug. @eb8680