numpyro

`TraceEnum_ELBO`: Subsample local variables that depend on a global model-enumerated variable

Open • ordabayevy opened this issue 2 years ago • 5 comments

One feature not supported by `TraceEnum_ELBO` is subsampling a local variable that depends on a global variable enumerated in the model, because all enumerated sample sites are required to share a common scale:

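import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate

@config_enumerate
def model(data):
    # Global variable, enumerated in the model.
    locs = jnp.array([1., 10.])
    a = numpyro.sample('a', dist.Categorical(probs=jnp.ones(2) / 2))
    with numpyro.plate('data', len(data), subsample_size=2) as ind:  # cannot subsample here
        # Local variable that depends on the enumerated global `a`.
        numpyro.sample('b', dist.Normal(locs[a], 1.), obs=data[ind])

def guide(data):
    pass  # `a` is enumerated in the model and `b` is observed, so the guide has no sites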
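With `a` enumerated in the model and an empty guide, the objective should reduce to the marginal log-likelihood $\log \sum_a p(a) \prod_{i=1}^{N} p(b_i \mid a)$. The sum over $a$ sits outside the product over the plate, which is roughly why a per-plate subsampling scale cannot simply be attached to one factor. A minimal NumPy sketch of that quantity on full (not subsampled) data, using the model's `locs = [1., 10.]` and uniform prior; the data values and the helpers `logsumexp` and `normal_logpdf` are written here for illustration only:

```python
import numpy as np

def logsumexp(x):
    # Numerically stable log(sum(exp(x))) for a 1-D array.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

def normal_logpdf(x, loc, scale=1.0):
    # log N(x | loc, scale)
    return -0.5 * np.log(2 * np.pi) - np.log(scale) - 0.5 * ((x - loc) / scale) ** 2

locs = np.array([1.0, 10.0])
log_p_a = np.log(np.array([0.5, 0.5]))  # uniform prior over a in {0, 1}

def exact_log_marginal(data):
    # log_lik[i, k] = log p(b_i | a = k), shape (N, 2)
    log_lik = normal_logpdf(data[:, None], locs[None, :])
    # Product over the plate first (sum of logs), then sum over a via logsumexp.
    return logsumexp(log_p_a + log_lik.sum(axis=0))

data = np.array([0.8, 1.3, 9.7, 10.4])  # hypothetical observations
print(exact_log_marginal(data))
```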

This has been asked on the forum as well: https://forum.pyro.ai/t/enumeration-and-subsampling-expected-all-enumerated-sample-sites-to-share-common-poutine-scale/4938

The proposed solution is to scale the log factors as follows ($N$: total size, $M$: subsample size): $\log \sum_a p(a) {\prod_i}^{N} p(b_i | a) \approx \frac{N}{M}\log \sum_a p(a) {\prod_i}^{M} p(b_i | a)$

Expectation of the left-hand side: $\mathbb{E} [ \log \sum_a p(a) {\prod_i}^{N} p(b_i | a) ] = \mathbb{E} [ \log {\prod_i}^{N} \sum_a p(a) p(b_i | a) ] = \mathbb{E} [ \log {\prod_i}^{N} p(b_i) ] = \mathbb{E} [{\sum_i}^N \log p(b_i) ] = {\sum_i}^N \mathbb{E} [ \log p(b_i) ] = N \mathbb{E} [ \log p(b_i) ]$

Expectation of the right-hand side: $\mathbb{E} [ \frac{N}{M} \log \sum_a p(a) {\prod_i}^{M} p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log {\prod_i}^{M} \sum_a p(a) p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log {\prod_i}^{M} p(b_i) ] = \frac{N}{M} \mathbb{E} [{\sum_i}^M \log p(b_i) ] = \frac{N}{M} {\sum_i}^M \mathbb{E} [ \log p(b_i) ] = N \mathbb{E} [ \log p(b_i) ]$

ordabayevy commented on Apr 09 '23

Hi @ordabayevy, I don't understand how you can interchange the product and the sum. In particular, I'm not sure your first equation makes sense: $\mathbb{E} [ \log \sum_a p(a) {\prod_i}^{N} p(b_i | a) ] = \mathbb{E} [ \log {\prod_i}^{N} \sum_a p(a) p(b_i | a) ]$. Could you clarify?
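A small numerical check shows that the interchange changes the value. The swap would only hold if $a$ were a local variable drawn independently for each $i$. This is a minimal NumPy sketch with two observations and hypothetical likelihood values $p(b_1 \mid a) = (0.2, 0.8)$ and $p(b_2 \mid a) = (0.6, 0.4)$ under a uniform prior. It also evaluates the first proposed estimator with $M = 1$, averaged over both possible subsamples, and that average matches the swapped quantity rather than the target:

```python
import numpy as np

p_a = np.array([0.5, 0.5])       # p(a) for a = 0, 1
p_b = np.array([[0.2, 0.8],      # hypothetical p(b_1 | a)
                [0.6, 0.4]])     # hypothetical p(b_2 | a)
N, M = 2, 1

# log sum_a p(a) prod_i p(b_i | a): sum over a outside the product
sum_outside = np.log(np.sum(p_a * np.prod(p_b, axis=0)))
# log prod_i sum_a p(a) p(b_i | a): as if a were resampled for every i
prod_outside = np.log(np.prod(p_b @ p_a))

# First proposal, (N/M) * log sum_a p(a) prod_{i in I_M} p(b_i | a),
# averaged over the two equally likely subsamples {b_1} and {b_2}
first_proposal = np.mean([(N / M) * np.log(np.sum(p_a * p_b[i])) for i in range(N)])
print(np.exp(sum_outside), np.exp(prod_outside))
```

Here the target is $\log 0.22$ while the swapped expression, and the mean of the first proposal, is $\log 0.25$.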

fehiepsi commented on Sep 01 '23

I think you are right, @fehiepsi. Let me think more about this.

ordabayevy commented on Sep 05 '23

So the actual equation (the same one used in the code) should be $\log \sum_a p(a) {\prod_i}^{N} p(b_i | a) \approx \log \sum_a p(a) \left ( {\prod_{i \in I_M}} p(b_i | a) \right ) ^ {\frac{N}{M}}$, where $I_M$ is a random subset of $M$ of the $N$ indices.

This seems intuitive to me: subsample within the plate, then raise the product to the power $\frac{N}{M}$ before summing over $a$. I ran some tests and the estimator appears to be unbiased. However, I can't figure out how to prove unbiasedness mathematically.
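For a tiny dataset, the expectation over subsamples can be computed exactly by enumerating all $\binom{N}{M}$ subsets $I_M$ instead of sampling them. This is a minimal NumPy sketch that assumes made-up data with $N = 4$ and $M = 2$, unit-variance Normal likelihoods with locations $(0, 1)$, and the prior $p(a) = (0.3, 0.7)$. The functions `logsumexp` and `normal_logpdf` are small helpers written here, not library calls:

```python
import itertools
import numpy as np

def logsumexp(x):
    # Numerically stable log(sum(exp(x))) for a 1-D array.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

def normal_logpdf(x, loc):
    # log N(x | loc, 1)
    return -0.5 * np.log(2 * np.pi) - 0.5 * (x - loc) ** 2

log_p_a = np.log(np.array([0.3, 0.7]))    # log p(a)
locs = np.array([0.0, 1.0])               # Normal location for a = 0 and a = 1
b = np.array([-0.5, 0.2, 1.1, 1.8])       # hypothetical data
N, M = len(b), 2

log_lik = normal_logpdf(b[:, None], locs[None, :])  # log p(b_i | a), shape (N, 2)
exact = logsumexp(log_p_a + log_lik.sum(axis=0))

# Enumerate every subsample I_M so the mean below is exact, not Monte Carlo.
scaled_sums, estimates = [], []
for idx in itertools.combinations(range(N), M):
    # (N/M) * sum_{i in I_M} log p(b_i | a), one value per a
    scaled = (N / M) * log_lik[list(idx)].sum(axis=0)
    scaled_sums.append(scaled)
    estimates.append(logsumexp(log_p_a + scaled))

mean_estimate = np.mean(estimates)
print(exact, mean_estimate)
```

Each scaled sum is an unbiased estimate of $\sum_{i=1}^{N} \log p(b_i \mid a)$. Logsumexp is convex, however, so Jensen's inequality gives $\mathbb{E}[\text{estimate}] \ge \log \sum_a p(a) \prod_{i=1}^{N} p(b_i \mid a)$. Equality holds only when the scaled sums for the different values of $a$ shift together across subsamples. On this toy data the exact mean over the six subsamples exceeds the true value by about 0.3.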

ordabayevy commented on Sep 06 '23

Code I used to check unbiasedness:

import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

a = torch.tensor([0., 1.])  # the two values of a, used as Normal locations
logits_a = torch.log(torch.tensor([0.3, 0.7]))  # log p(a)

# 500 values drawn uniformly from [0, 1) and 500 from [1, 2)
b = torch.rand(1000)
b[500:] += 1

d = dist.Normal(a, 1)
log_b = d.log_prob(b.reshape(-1,1))  # log p(b_i | a), shape (1000, 2)

# exact value: log sum_a p(a) prod_i p(b_i | a)
expected = torch.logsumexp(log_b.sum(0) + logits_a, 0)

results = []
for _ in range(50000):
    idx = torch.randperm(1000)[:100]  # subsample 100 of the 1000 points without replacement
    scale = 10  # N / M = 1000 / 100
    # log sum_a p(a) (prod_{i in I_M} p(b_i | a))^(N/M)
    results.append(torch.logsumexp(d.log_prob(b[idx].reshape(-1,1)).sum(0) * scale + logits_a, 0))

print(expected)
print(torch.mean(torch.tensor(results)))

plt.plot(results)
plt.hlines(expected, 0, 50000, "C1")
plt.show()

>>> tensor(-1087.9426)
>>> tensor(-1087.8069)

[Figure: the 50,000 subsampled estimates plotted in order, with the exact value as a horizontal line]

ordabayevy commented on Sep 06 '23

I think the issue is easier to see with a smaller dataset, e.g. just two points $x$ and $y$. Suppose we use subsampling (subsample size 1, scale 2) to estimate $\log \sum_a p(a) p(x|a)p(y|a)=\log \sqrt{\sum_{a,b} p(a)p(b)p(x|a)p(x|b)p(y|a)p(y|b)}$. The expected value of the estimate is $\frac{1}{2} \left( \log \sum_a p(a) p(x|a)^2 + \log \sum_a p(a) p(y|a)^2 \right) = \log \sqrt{\sum_{a,b} p(a)p(b) p(x|a)^2 p(y|b)^2}$. It seems clear to me that the two terms $\sum_{a,b} p(a)p(b)p(x|a)p(x|b)p(y|a)p(y|b)$ and $\sum_{a,b} p(a)p(b) p(x|a)^2 p(y|b)^2$ are not equal.
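The two-point example can be checked numerically. This is a minimal NumPy sketch that assumes hypothetical values $p(a) = (0.5, 0.5)$, $p(x \mid a) = (0.2, 0.8)$, and $p(y \mid a) = (0.6, 0.4)$. With subsample size 1, each of $\{x\}$ and $\{y\}$ is picked with probability 1/2, and its factor is raised to the power 2:

```python
import numpy as np

p_a = np.array([0.5, 0.5])
p_x = np.array([0.2, 0.8])  # hypothetical p(x | a)
p_y = np.array([0.6, 0.4])  # hypothetical p(y | a)

# Target: log sum_a p(a) p(x|a) p(y|a)
target = np.log(np.sum(p_a * p_x * p_y))

# Expected value of the subsampled estimator (M = 1, scale N/M = 2)
expected_estimate = 0.5 * (np.log(np.sum(p_a * p_x ** 2)) + np.log(np.sum(p_a * p_y ** 2)))

# The two double sums over (a, b) from the comment above
lhs = np.sum(np.outer(p_a, p_a) * np.outer(p_x, p_x) * np.outer(p_y, p_y))
rhs = np.sum(np.outer(p_a, p_a) * np.outer(p_x ** 2, p_y ** 2))
print(lhs, rhs)
```

With these numbers the two double sums are $0.22^2 = 0.0484$ and $0.34 \times 0.26 = 0.0884$, so the expected estimate overshoots the target.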

fehiepsi commented on Sep 07 '23