
fix sequential enumeration

Open qinqian opened this issue 1 year ago • 6 comments

This pull request fixes the bug reported in the GitHub issue.

With the same code, sequential enumeration gives a mean of 0.6269999742507935 over 10000 infer_discrete calls with temperature=1. Changing the enum variable to parallel gives a mean of 0.6294000148773193. Using temperature=0 for MAP estimation of y_pre gives mean=1 for both parallel and sequential enumeration. These tests were run on a GCP VM with an Ubuntu Docker image.

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate
from pyro.infer import infer_discrete

enum = "sequential"  # set to "parallel" to compare the two enumeration strategies

@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    p = x_pa_obs
    y = pyro.sample('y_pre', dist.Binomial(probs=p, total_count=1),
                    infer={"enumerate": enum},
                    obs=y_obs)

    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)

    return y


data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))

smpl = torch.stack(y_posts)
print(smpl.shape)
print(f"mean: {smpl.mean()}")
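
As a sanity check on the means reported above, the exact posterior for this two-state model can be worked out by hand, since y_pre only takes the values 0 and 1. The sketch below is not part of the PR; the helper names `normal_pdf` and `posterior_y1` are made up for illustration, and it uses only the standard library rather than Pyro.

```python
import math


def normal_pdf(x, mean, std=1.0):
    """Density of Normal(mean, std) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def posterior_y1(p, x_obs):
    """Exact P(y=1 | x_obs) for y ~ Bernoulli(p), x ~ Normal(y, 1)."""
    w1 = p * normal_pdf(x_obs, 1.0)          # joint weight of y=1
    w0 = (1.0 - p) * normal_pdf(x_obs, 0.0)  # joint weight of y=0
    return w1 / (w0 + w1)


# Same inputs as data_obs above: x_pa_obs=0.5, x_ch_obs=1.0
exact = posterior_y1(p=0.5, x_obs=1.0)

# Monte Carlo standard error of a mean over 10**4 Bernoulli draws
std_err = math.sqrt(exact * (1.0 - exact) / 10**4)

print(f"exact P(y=1 | x=1): {exact:.4f}")
print(f"standard error for 10**4 samples: {std_err:.4f}")
```

The exact value is 1 / (1 + exp(-0.5)), about 0.6225, so both reported means lie within two standard errors of it. Since y=1 is the more probable state, a mean of 1 at temperature=0 is also what MAP estimation should return. Sampling y_pre from the prior alone, ignoring x_ch_obs, would instead give a mean near 0.5.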

qinqian avatar Jul 13 '23 01:07 qinqian

@qinqian thanks for looking at this. Can you turn the example above into a unit test in tests/infer/test_discrete.py that fails without this fix and passes with it?

eb8680 avatar Jul 14 '23 14:07 eb8680

Yes @eb8680, I turned this example into a unit test in tests/infer/test_discrete.py. I added two tests that pass with the fix and four cases that fail without it. To simulate the example without the fix, enum was set to a different value.

qinqian avatar Jul 17 '23 03:07 qinqian

Does anyone know why GitHub Actions is not running after new commits?

ordabayevy avatar Oct 04 '23 12:10 ordabayevy

I think we may need permission to kick off GitHub Actions. Thanks for your interest in the pull request. I tested it locally, and it should work this time.

qinqian avatar Oct 05 '23 01:10 qinqian

Sorry about the GitHub bug; sometimes I've needed to close a PR and open another.

Aside from tests, can you explain your diagnosis of the problem and your proposed solution? From what I can tell, this PR amounts to "if the user says 'sequential' pretend they said 'parallel'", which seems wrong. But maybe I'm missing something.

fritzo avatar Oct 05 '23 01:10 fritzo

Yes @fritzo. The problem is that sequential enumeration generates different results from parallel enumeration. I diagnosed it in two ways.

First, I added a breakpoint to check the function here with the simple example above. The key difference between the two enumerations comes from enum_terms, which is always empty for sequential enumeration. That means no ELBO loss is added for this term, so it tends to a random ELBO, 0.5 for the simple example above. I then traced the function to the EnumMessenger class (https://github.com/pyro-ppl/pyro/blob/bca99e296768f87fd8a51ef26e03d75867dc736b/pyro/poutine/enum_messenger.py#L123) and found that it did add loss for the sequential option.

Second, I compared Pyro's EnumMessenger with the one in funsor, which uses an approach similar to the one I proposed above.

Please let me know if I have misunderstood the problem.

qinqian avatar Oct 05 '23 15:10 qinqian