pyro
fix sequential enumeration
This pull request fixes the bug reported in the GitHub issue.
With the same code (below), sequential enumeration gives a mean of 0.6269999742507935 over 10000 `infer_discrete` calls with `temperature=1`. Changing the `enum` variable to `parallel` gives a mean of 0.6294000148773193. Using `temperature=0` for MAP estimation of `y_pre` gives mean=1 for both `parallel` and `sequential` enumeration. These tests were run on a GCP VM with an Ubuntu Docker image.
```python
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate
from pyro.infer import infer_discrete

enum = "sequential"  # change to "parallel" to compare the two strategies

@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    p = x_pa_obs
    # Discrete latent y ~ Bernoulli(p), written as a Binomial with total_count=1
    y = pyro.sample('y_pre', dist.Binomial(probs=p, total_count=1),
                    infer={"enumerate": enum},
                    obs=y_obs)
    # Observed child x_ch ~ Normal(y, 1)
    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)
    return y

data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
# temperature=1 samples y from its posterior; temperature=0 returns the MAP value
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))
smpl = torch.stack(y_posts)
print(smpl.shape)
print(f"mean: {smpl.mean()}")
```
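As a sanity check on the numbers above, the exact posterior for this toy model can be computed by hand. With prior P(y=1) = 0.5 and x_ch ~ Normal(y, 1) observed at x_ch = 1, Bayes' rule gives P(y=1 | x_ch=1) = 1 / (1 + exp(-0.5)) ≈ 0.6225, so the MAP value is y = 1. A minimal standard-library sketch; the helper names `normal_pdf` and `exact_posterior` are hypothetical and not part of Pyro:

```python
import math

def normal_pdf(x, mean, scale=1.0):
    """Density of Normal(mean, scale) evaluated at x."""
    z = (x - mean) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))

def exact_posterior(p, x_ch):
    """P(y=1 | x_ch) for y ~ Bernoulli(p) and x_ch ~ Normal(y, 1)."""
    joint_1 = p * normal_pdf(x_ch, 1.0)          # p(y=1) * p(x_ch | y=1)
    joint_0 = (1.0 - p) * normal_pdf(x_ch, 0.0)  # p(y=0) * p(x_ch | y=0)
    return joint_1 / (joint_1 + joint_0)

post = exact_posterior(0.5, 1.0)
map_y = 1 if post > 0.5 else 0  # MAP estimate of the binary latent
print(f"P(y=1 | x_ch=1) = {post:.4f}")  # P(y=1 | x_ch=1) = 0.6225
print(f"MAP y = {map_y}")               # MAP y = 1
```

Both reported `temperature=1` means (0.6270 and 0.6294) are within one percentage point of this exact value, and the MAP value matches the `temperature=0` result of mean=1.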
@qinqian thanks for looking at this. Can you turn the example above into a unit test in `tests/infer/test_discrete.py` that fails without this fix and passes with it?
Yes @eb8680, I turned this example into a unit test in `tests/infer/test_discrete.py`. I added 2 tests that pass with the fix and 4 cases that fail without it. To simulate the example without the fix, `enum` was set to `other`.
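When a sampling example like this becomes a unit test, the tolerance can be derived from the standard error of the sample mean: the mean of N Bernoulli(q) draws has standard deviation sqrt(q(1 - q) / N). A minimal sketch, assuming the expected value is the exact posterior 1 / (1 + exp(-0.5)) and a 4-sigma band; the helper names and the 4-sigma choice are hypothetical and not taken from `tests/infer/test_discrete.py`:

```python
import math

def bernoulli_mean_stderr(q, num_samples):
    """Standard error of the mean of num_samples Bernoulli(q) draws."""
    return math.sqrt(q * (1.0 - q) / num_samples)

def mean_is_close(observed_mean, expected_mean, num_samples, num_sigmas=4.0):
    """True if observed_mean is within num_sigmas standard errors of expected_mean."""
    tol = num_sigmas * bernoulli_mean_stderr(expected_mean, num_samples)
    return abs(observed_mean - expected_mean) <= tol

expected = 1.0 / (1.0 + math.exp(-0.5))  # exact P(y=1 | x_ch=1), about 0.6225
n = 10**4
print(f"tolerance: {4.0 * bernoulli_mean_stderr(expected, n):.4f}")  # tolerance: 0.0194
print(mean_is_close(0.6269999742507935, expected, n))  # True  (reported sequential mean)
print(mean_is_close(0.5, expected, n))                 # False (mean that ignores x_ch)
```

A 4-sigma band keeps the chance of a spurious failure small while still separating the correct posterior mean from 0.5 by a wide margin.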
Does anyone know why GitHub Actions aren't running on the new commits?
I think we may need permission to trigger GitHub Actions. Thanks for your interest in this pull request. I tested it locally, and it should work this time.
Sorry about the GitHub bug; sometimes I've needed to close a PR and open another one.
Aside from tests, can you explain your diagnosis of the problem and your proposed solution? From what I can tell, this PR amounts to "if the user says 'sequential' pretend they said 'parallel'", which seems wrong. But maybe I'm missing something.
Yes @fritzo. The problem is that sequential enumeration generates different results from parallel enumeration. I used two approaches to diagnose it.
The first was to add a breakpoint in the function here and run the simple example above. I found that the key difference between the two enumerations comes from `enum_terms`, which is always empty for `sequential` enumeration. That means no ELBO term is added for this site, so the sampled value is essentially random (a mean of about 0.5 for the simple example above). I then traced the function to the `EnumMessenger` class (https://github.com/pyro-ppl/pyro/blob/bca99e296768f87fd8a51ef26e03d75867dc736b/pyro/poutine/enum_messenger.py#L123) and found that it does add the loss for the `sequential` option.
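To illustrate the symptom described in the first approach, here is a toy standard-library simulation, not Pyro's implementation. One sampler draws y in proportion to prior × likelihood; the other drops the likelihood factor for `x_ch_pre`, which is assumed here to be the effect of the missing enumeration term. All function names and the seed are hypothetical:

```python
import math
import random

def unnormalized_normal(x, mean):
    """Normal(mean, 1) density at x without the constant factor (it cancels)."""
    return math.exp(-0.5 * (x - mean) ** 2)

def sample_y(rng, p, x_ch, include_likelihood=True):
    """Draw y in {0, 1} from prior * likelihood, optionally dropping the likelihood."""
    w1, w0 = p, 1.0 - p  # prior weights for y=1 and y=0
    if include_likelihood:
        w1 *= unnormalized_normal(x_ch, 1.0)
        w0 *= unnormalized_normal(x_ch, 0.0)
    return 1 if rng.random() < w1 / (w0 + w1) else 0

def sampled_mean(include_likelihood, n=10**4, seed=0):
    """Mean of n draws of y for the example's data (p=0.5, x_ch=1.0)."""
    rng = random.Random(seed)
    draws = [sample_y(rng, 0.5, 1.0, include_likelihood) for _ in range(n)]
    return sum(draws) / n

print(f"with likelihood term:    {sampled_mean(True):.3f}")
print(f"without likelihood term: {sampled_mean(False):.3f}")
```

The first mean should land near the exact posterior 0.6225, while the second stays near 0.5, matching the behavior described for empty `enum_terms`.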
The second was to compare Pyro's `EnumMessenger` with funsor, which uses an approach similar to the one I proposed above.
Please let me know if I have misunderstood the problem.