
[FR] Enumeration support for `Independent` distributions with `reinterpreted_batch_ndims=1`

Open ordabayevy opened this issue 3 years ago • 1 comment

This is a feature request for enumeration support for `Independent` distributions with `reinterpreted_batch_ndims=1`, e.g.:

>>> d = dist.Categorical(torch.ones(2, 3)).to_event(1)
>>> x = d.enumerate_support()
>>> print(x)
tensor([[[0., 0.],
         [0., 1.],
         [0., 2.]],

        [[1., 0.],
         [1., 1.],
         [1., 2.]],

        [[2., 0.],
         [2., 1.],
         [2., 2.]]])
>>> x.shape
torch.Size([3, 3, 2])

Implementation of `.enumerate_support()` should probably be done in PyTorch, but I'm curious whether this kind of enumeration could be supported by `TraceEnum_ELBO`?
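For reference, the requested support could be built as a Cartesian product of the per-coordinate supports. This is only a minimal sketch (the helper `enumerate_independent_support` is hypothetical, not a Pyro or PyTorch API), assuming integer-valued supports like `Categorical`:

```python
import itertools

import torch


def enumerate_independent_support(support_size, event_size):
    """Hypothetical sketch: enumerate the joint support of `event_size`
    independent coordinates, each taking values in {0, ..., support_size - 1}."""
    combos = list(itertools.product(range(support_size), repeat=event_size))
    # Shape: (support_size ** event_size, event_size).
    return torch.tensor(combos)


x = enumerate_independent_support(support_size=3, event_size=2)
print(x.shape)  # torch.Size([9, 2])
```

Reshaping `x` to `(3, 3, 2)` reproduces (up to dtype) the tensor printed above; note the joint support has `3 ** 2 = 9` elements rather than the base distribution's 3.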

Use case

I have for-loops in my model that I think could be vectorized if this functionality were supported:

h, m = {}, {}
for i in range(2):
    h[i] = pyro.sample(f"h_{i}", dist.Normal(0, 1))
    m[i] = pyro.sample(f"x_{i}", dist.Categorical(torch.ones(3)), infer={"enumerate": "parallel"})
loc = m[0] * h[0] + m[1] * h[1]
pyro.sample("y", dist.Normal(loc, 1), obs=data)

I want to write it as:

h = pyro.sample("h", dist.Normal(torch.zeros(2), 1).to_event(1))
m = pyro.sample("x", dist.Categorical(torch.ones(2, 3)).to_event(1), infer={"enumerate": "parallel"})
loc = (m * h).sum(-1)
pyro.sample("y", dist.Normal(loc, 1), obs=data)

ordabayevy avatar Nov 30 '21 20:11 ordabayevy

This is probably better done upstream in PyTorch, but you can add a patch in Pyro if you want. Note that we intentionally left this unimplemented because it leads to exponential growth of the enumerated support, which often causes OOM errors and a bad user experience.
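To make the exponential growth concrete: enumerating a `Categorical` with `K` categories reinterpreted over `D` event dimensions yields `K ** D` joint support values, which blows up quickly even for modest `D`:

```python
# Joint support size K ** D for K categories and D independent coordinates.
for K, D in [(3, 2), (3, 10), (10, 10)]:
    print(f"K={K}, D={D}: {K ** D:,} enumerated values")
```

So `D = 2` as in the example above is cheap (9 values), but `K = D = 10` already requires ten billion values.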

fritzo avatar Dec 01 '21 01:12 fritzo