
🐛 LocScaleReparam and enumeration with NUTS

Open pietrolesci opened this issue 3 years ago • 4 comments

Context

Hi there,

While trying to reproduce the annotators.py example (originally written in numpyro) in pyro, I came across what might be a bug. It was first discussed on the pyro forum and, as suggested by @fehiepsi, I am opening this issue.

Description

LocScaleReparam seems to cause problems when combined with automatic enumeration of discrete variables during MCMC inference with NUTS. The issue does not appear when the reparametrization is done manually. The code below reproduces the issue.

Code

import pyro
import pyro.distributions as dist
import torch
import torch.nn.functional as F
from pyro.infer import MCMC, NUTS
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.indexing import Vindex
from pyro.poutine import reparam


def get_data():
    """
    :return: a tuple of annotator indices and class indices. The first term has shape
        `num_positions` whose entries take values from `0` to `num_annotators - 1`.
        The second term has shape `num_items x num_positions` whose entries take values
        from `0` to `num_classes - 1`.
    """
    # NB: the first annotator assessed each item 3 times
    positions = torch.tensor([1, 1, 1, 2, 3, 4, 5])
    annotations = torch.tensor(
        [
            [1, 1, 1, 1, 1, 1, 1],
            [3, 3, 3, 4, 3, 3, 4],
            [1, 1, 2, 2, 1, 2, 2],
            [2, 2, 2, 3, 1, 2, 1],
            [2, 2, 2, 3, 2, 2, 2],
            [2, 2, 2, 3, 3, 2, 2],
        ]
    )
    # subtract 1 because Python indices start at 0
    return positions - 1, annotations - 1


def hierarchical_dawid_skene(positions: torch.Tensor, annotations: torch.Tensor) -> None:
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = positions.unique().numel()
    num_classes = annotations.unique().numel()
    num_items, num_positions = annotations.shape
    # print(f"{num_classes=}, {num_annotators=}, {num_items=}, {num_positions=}")

    with pyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = pyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = pyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class_abilities", num_classes):
            # non-centered parameterization
            with reparam(config={"beta": LocScaleReparam(centered=0.0)}):
                beta = pyro.sample("beta", dist.Normal(zeta, omega).to_event(1))

            # NOTE: the manual reparametrization below works
            # beta_base = pyro.sample("beta_base", dist.Normal(0., 1.).expand([num_classes - 1]).to_event(1))
            # beta = beta_base * omega + zeta

            # pad 0 last dimension
            beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))

    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    with pyro.plate("item", num_items, dim=-2):
        c = pyro.sample("c", dist.Categorical(probs=pi))

        with pyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            pyro.sample("y", dist.Categorical(logits=logits), obs=annotations)


if __name__ == "__main__":

    model = hierarchical_dawid_skene
    data = get_data()

    mcmc = MCMC(
        NUTS(model),
        warmup_steps=500,
        num_samples=10_000,
        num_chains=1,
    )
    mcmc.run(*data)

Details

In the first MCMC iteration with NUTS (following the numpyro example), I get the following shapes

num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([10, 1]), beta.shape=torch.Size([5, 4, 4])

In the second iteration, when c is enumerated, I get the following shapes

num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([4, 1, 1]), beta.shape=torch.Size([4, 4])

Notably, the shape of beta changes from (5, 4, 4) to (4, 4), i.e. the leading annotator dimension is dropped. This does not happen when I remove the reparametrization.
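For reference, here is a minimal torch-only sketch of the shapes the padding and indexing steps would be expected to produce when beta keeps its annotator dimension and c is enumerated. The sizes are taken from the output above. Plain advanced indexing stands in for Vindex here, and the random beta is purely illustrative; both are assumptions of this sketch, not part of the original report.

```python
import torch
import torch.nn.functional as F

# Sizes taken from the printed shapes above.
num_annotators, num_classes = 5, 4

# beta before padding: (annotators, classes, classes - 1); random values for illustration only.
beta = torch.randn(num_annotators, num_classes, num_classes - 1)

# Same padding expression as in the model: append one zero column on the last dimension.
beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))

# Annotator index for each of the 7 positions (already shifted to start at 0).
positions = torch.tensor([0, 0, 0, 1, 2, 3, 4])

# An enumerated `c` with the shape reported above: one slice per class value.
c = torch.arange(num_classes).reshape(num_classes, 1, 1)

# The index tensors broadcast to (4, 1, 7); the trailing class dimension is kept.
logits = beta[positions, c, :]
print(beta.shape, logits.shape)
```

If beta arrives with shape (4, 4) instead, there is no annotator axis left for positions to index, so the lookup no longer matches the model's intent.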

Watermark

  • python 3.8
# Name                    Version                   Build  Channel
numpyro                   0.8.0                    pypi_0    pypi
pyro-api                  0.1.2                    pypi_0    pypi
pyro-ppl                  1.7.0                    pypi_0    pypi

pietrolesci • Nov 18 '21 11:11

Hi @pietrolesci, it turns out that you need to apply reparam outside of the plate statements:

    with reparam(config={"beta": LocScaleReparam(centered=0.0)}):
        with pyro.plate("annotator", num_annotators, dim=-2):
            with pyro.plate("class_abilities", num_classes):
                ...

I think this is better practice than putting reparam inside a plate. We will also try to add a warning message to help users fix their models.

fehiepsi • Dec 11 '21 13:12

Hi @fehiepsi,

Thanks for your feedback. I'll try that out asap. Is this mentioned in the docs?

Also, is this true for numpyro as well? If so, it may make sense to update the annotators example in the docs accordingly.

Again, thanks for looking into this.

pietrolesci • Dec 11 '21 14:12

For numpyro it is currently not required. But to be safe (in case we change the behavior of init strategies in numpyro in the future), you can follow the above practice.

fehiepsi • Dec 11 '21 16:12

Thanks a lot @fehiepsi for your feedback!

pietrolesci • Dec 13 '21 14:12