[FR] Enumeration support for `Independent` distributions with `reinterpreted_batch_ndims=1`
This is a feature request for enumeration support for `Independent` distributions with `reinterpreted_batch_ndims=1`, e.g.:
>>> d = dist.Categorical(torch.ones(2, 3)).to_event(1)
>>> x = d.enumerate_support()
>>> print(x)
tensor([[[0., 0.],
         [0., 1.],
         [0., 2.]],

        [[1., 0.],
         [1., 1.],
         [1., 2.]],

        [[2., 0.],
         [2., 1.],
         [2., 2.]]])
>>> x.shape
torch.Size([3, 3, 2])
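For reference, the tensor above is the Cartesian product of the per-dimension supports. A minimal sketch of how it could be built with plain PyTorch follows; `enumerate_product_support` is a hypothetical helper name, not an existing Pyro or PyTorch API, and it assumes every event dimension has the same number of categories. It returns integer indices rather than the floats printed above.

```python
import torch


def enumerate_product_support(num_categories: int, event_size: int) -> torch.Tensor:
    """Hypothetical helper: enumerate every joint assignment of `event_size`
    categorical variables that each take values in range(num_categories)."""
    values = torch.arange(num_categories)
    # cartesian_prod yields one row per joint assignment; the reshape also
    # covers event_size == 1, where cartesian_prod returns a 1-D tensor.
    flat = torch.cartesian_prod(*[values] * event_size).reshape(-1, event_size)
    # Give each variable its own enumeration dimension, as in the example above.
    return flat.reshape((num_categories,) * event_size + (event_size,))


x = enumerate_product_support(num_categories=3, event_size=2)
print(x.shape)  # torch.Size([3, 3, 2])
```

Here `x[i, j]` holds the pair `[i, j]`, matching the layout of the requested output.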
`.enumerate_support` for this case should probably be implemented in PyTorch, but I'm curious whether this kind of enumeration can be supported by `TraceEnum_ELBO`.
Use case
I have `for` loops in my model that I think could be vectorized if this functionality were supported:
h, m = {}, {}
for i in range(2):
    h[i] = pyro.sample(f"h_{i}", dist.Normal(0, 1))
    m[i] = pyro.sample(f"x_{i}", dist.Categorical(torch.ones(3)), infer={"enumerate": "parallel"})
loc = m[0] * h[0] + m[1] * h[1]
pyro.sample("y", dist.Normal(loc, 1), obs=data)
I would like to write it as:
h = pyro.sample("h", dist.Normal(torch.zeros(2), 1).to_event(1))
m = pyro.sample("x", dist.Categorical(torch.ones(2, 3)).to_event(1), infer={"enumerate": "parallel"})
loc = (m * h).sum(-1)
pyro.sample("y", dist.Normal(loc, 1), obs=data)
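As a sanity check on the two forms, the sketch below uses plain PyTorch (no Pyro) to compare the `loc` tensors they would produce if the requested enumeration existed. The values in `h` are arbitrary stand-ins for the Normal draws, and the dimension layout is chosen for illustration only; it is not claimed to match how Pyro allocates enumeration dimensions.

```python
import torch

K = 3                            # number of categories per variable
h = torch.tensor([0.5, -1.25])   # stand-in values for the two Normal draws

# Loop form: each enumerated site varies along its own tensor dimension.
m0 = torch.arange(K).reshape(K, 1)   # enumerated values of x_0
m1 = torch.arange(K).reshape(1, K)   # enumerated values of x_1
loc_loop = m0 * h[0] + m1 * h[1]     # broadcasts to shape (K, K)

# Vectorized form: stack the broadcast values so that m[i, j] == [i, j].
m = torch.stack(torch.broadcast_tensors(m0, m1), dim=-1)   # shape (K, K, 2)
loc_vec = (m * h).sum(-1)                                  # shape (K, K)

# Both forms give the same location for every joint assignment.
assert torch.allclose(loc_loop, loc_vec)
```

Both versions evaluate `loc` at all 9 joint assignments, so vectorizing changes how the model is written, not how many configurations are enumerated.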
This is probably better done upstream in PyTorch, but you can add a patch in Pyro if you want. Note that we intentionally left this unimplemented: the enumerated support grows exponentially with the event size, which often leads to out-of-memory errors and a bad user experience.
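To make the growth concrete: with K categories per dimension and an event of size N, the joint support has K**N elements. The rough estimate below uses a hypothetical helper, `enumeration_cost`, and assumes the enumerated values are stored as 8-byte integers; real memory use in inference would also include the downstream tensors that broadcast against the enumeration.

```python
def enumeration_cost(num_categories: int, event_size: int, bytes_per_element: int = 8):
    """Hypothetical helper: number of joint assignments and the bytes needed
    to store the enumerated support, one row of event_size values each."""
    rows = num_categories ** event_size
    return rows, rows * event_size * bytes_per_element


for k, n in [(3, 2), (3, 10), (10, 10)]:
    rows, nbytes = enumeration_cost(k, n)
    print(f"K={k}, N={n}: {rows} assignments, {nbytes} bytes")
```

For the example in this issue (K=3, N=2) the support has only 9 assignments, but at K=10, N=10 it already has 10**10.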