pyro
🐛 LocScaleReparam and enumeration with NUTS
Context
Hi there,
While trying to reproduce the annotators.py example (originally written in numpyro) in pyro, I came across what might be a bug. Discussion on the pyro forum here. As suggested by @fehiepsi, I am opening this issue.
Description
LocScaleReparam seems to cause problems when combined with the automatic enumeration of discrete variables during MCMC inference with NUTS. The same issue does not occur when the reparametrization is done manually (see the commented-out lines in the model below).
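For reference, `LocScaleReparam(centered=0.0)` is the fully non-centered parameterization, i.e. the same transformation as the manual workaround commented out in the model (`beta = beta_base * omega + zeta`). Here the second argument of the normal is the standard deviation, as in `dist.Normal(loc, scale)`:

```latex
\beta \sim \mathcal{N}(\zeta, \omega)
\quad\Longleftrightarrow\quad
\beta = \zeta + \omega \, \beta_{\text{base}}, \qquad \beta_{\text{base}} \sim \mathcal{N}(0, 1)
```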
Below I report the code to reproduce the issue.
Code
import pyro
import pyro.distributions as dist
import torch
import torch.nn.functional as F
from pyro.infer import MCMC, NUTS
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.indexing import Vindex
from pyro.poutine import reparam
def get_data():
    """
    :return: a tuple of annotator indices and class indices. The first term has shape
        `num_positions` whose entries take values from `0` to `num_annotators - 1`.
        The second term has shape `num_items x num_positions` whose entries take values
        from `0` to `num_classes - 1`.
    """
    # NB: the first annotator assessed each item 3 times
    positions = torch.tensor([1, 1, 1, 2, 3, 4, 5])
    annotations = torch.tensor(
        [
            [1, 1, 1, 1, 1, 1, 1],
            [3, 3, 3, 4, 3, 3, 4],
            [1, 1, 2, 2, 1, 2, 2],
            [2, 2, 2, 3, 1, 2, 1],
            [2, 2, 2, 3, 2, 2, 2],
            [2, 2, 2, 3, 3, 2, 2],
        ]
    )
    # subtract 1 because Python indices start at 0
    return positions - 1, annotations - 1
def hierarchical_dawid_skene(positions: torch.Tensor, annotations: torch.Tensor) -> None:
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = positions.unique().numel()
    num_classes = annotations.unique().numel()
    num_items, num_positions = annotations.shape
    # print(f"{num_classes=}, {num_annotators=}, {num_items=}, {num_positions=}")

    with pyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = pyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = pyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class_abilities", num_classes):
            # non-centered parameterization
            with reparam(config={"beta": LocScaleReparam(centered=0.0)}):
                beta = pyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # NOTE: in this way it works
            # beta_base = pyro.sample("beta_base", dist.Normal(0., 1.).expand([num_classes - 1]).to_event(1))
            # beta = beta_base * omega + zeta

            # pad a fixed 0 logit onto the last dimension
            beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))

    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    with pyro.plate("item", num_items, dim=-2):
        c = pyro.sample("c", dist.Categorical(probs=pi))

        with pyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            pyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
if __name__ == "__main__":
    model = hierarchical_dawid_skene
    data = get_data()
    mcmc = MCMC(
        NUTS(model),
        warmup_steps=500,
        num_samples=10_000,
        num_chains=1,
    )
    mcmc.run(*data)
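The padding step in the model, and the NB comment about fixing the last term of `beta` to 0, can be checked in isolation. This is a NumPy sketch, not part of the original report. It assumes `np.pad` with `[(0, 0)] * (ndim - 1) + [(0, 1)]` is the counterpart of `F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))`: PyTorch lists padding starting from the last dimension, NumPy from the first. The helpers `softmax` and `pad_last_logit` and the random stand-in values are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def pad_last_logit(beta: np.ndarray) -> np.ndarray:
    # Hypothetical helper: append a constant 0 to the last axis only,
    # mirroring F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1)).
    return np.pad(beta, [(0, 0)] * (beta.ndim - 1) + [(0, 1)])


rng = np.random.default_rng(0)
num_annotators, num_classes = 5, 4

# Free parameters: only num_classes - 1 logits per (annotator, true class) pair.
beta_free = rng.normal(size=(num_annotators, num_classes, num_classes - 1))
beta = pad_last_logit(beta_free)
print(beta.shape)  # (5, 4, 4)

# Adding the same constant to every logit leaves the probabilities unchanged,
# which is why one logit per row can be pinned to 0 without losing generality.
x = rng.normal(size=num_classes)
print(np.allclose(softmax(x), softmax(x + 3.0)))  # True
```

The `(5, 4, 4)` result matches the `beta.shape` reported for the first iteration below.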
Details
In the first MCMC iteration with NUTS (following the numpyro example), I get the following shapes:
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([10, 1]), beta.shape=torch.Size([5, 4, 4])
In the second iteration, when c is enumerated, I get the following shapes:
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([4, 1, 1]), beta.shape=torch.Size([4, 4])
Notably, beta changes from (5, 4, 4) to (4, 4). This does not happen when I remove the reparametrization.
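To see why the shape change matters, here is a NumPy analogue of the `Vindex(beta)[positions, c, :]` lookup. It assumes that, for this index pattern, NumPy advanced indexing gives the same result shape as Pyro's `Vindex`; zeros stand in for the real values, and the names `c_enum`, `beta_ok`, `beta_bad` and `bad_index_failed` are hypothetical. With `c` enumerated as shape `(4, 1, 1)`, a `(5, 4, 4)` beta yields logits of shape `(4, 1, 7, 4)`, while a `(4, 4)` beta cannot accept three indices at all.

```python
import numpy as np

num_annotators, num_classes, num_positions = 5, 4, 7
positions = np.array([1, 1, 1, 2, 3, 4, 5]) - 1  # shape (7,), as in get_data()

# Enumerated `c`: one value per class, on the dim left of both plates (dim -3).
c_enum = np.arange(num_classes).reshape(num_classes, 1, 1)  # shape (4, 1, 1)

# beta with the expected (annotator, true class, logit) layout.
beta_ok = np.zeros((num_annotators, num_classes, num_classes))
logits = beta_ok[positions, c_enum, :]  # index shapes (7,) and (4, 1, 1) broadcast
print(logits.shape)  # (4, 1, 7, 4)

# beta with the shape reported for the second iteration: the annotator dim is gone.
beta_bad = np.zeros((num_classes, num_classes))
bad_index_failed = False
try:
    beta_bad[positions, c_enum, :]
except IndexError as err:
    # Three indices on a 2-D array cannot be resolved.
    bad_index_failed = True
    print("IndexError:", err)
```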
Watermark
- python 3.8
# Name Version Build Channel
numpyro 0.8.0 pypi_0 pypi
pyro-api 0.1.2 pypi_0 pypi
pyro-ppl 1.7.0 pypi_0 pypi
Hi @pietrolesci, it turns out that you need to apply reparam outside of the plate statements:
with reparam(config={"beta": LocScaleReparam(centered=0.0)}):
    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class_abilities", num_classes):
            ...
I think this is better practice than putting reparam inside a plate. We will also try to add a warning message to help users fix such models.
Hi @fehiepsi,
Thanks for your feedback. I'll try that out asap. Is this mentioned in the docs?
Also, is this true for numpyro as well? If it is, it may make sense to update the annotators example in the docs accordingly.
Again, thanks for looking into this.
For numpyro it is currently not required. But to be safe (in case we change the behavior of init strategies in numpyro in the future), you can follow the above practice.
Thanks a lot @fehiepsi for your feedback!