`TraceEnum_ELBO`: Subsample local variables that depend on a global model-enumerated variable
One feature not supported by `TraceEnum_ELBO` is subsampling a local variable that depends on a global variable that is enumerated in the model, because all enumerated sample sites are required to share a common scale:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate

@config_enumerate
def model(data):
    # Global variables.
    locs = jnp.array([1., 10.])
    a = numpyro.sample('a', dist.Categorical(jnp.ones(2) / 2))
    with numpyro.plate('data', len(data), subsample_size=2) as ind:  # cannot subsample here
        # Local variables.
        numpyro.sample('b', dist.Normal(locs[a], 1.), obs=data[ind])

def guide(data):
    pass
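For context, here is a minimal sketch (not from the issue; the synthetic data, optimizer, and step count are illustrative assumptions) of how one would try to fit the model and guide above with numpyro's SVI and TraceEnum_ELBO; it fails because all enumerated sample sites are required to share a common scale, which the subsampled plate breaks:

import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO

# Synthetic data (made up for illustration).
data = dist.Normal(1., 1.).sample(random.PRNGKey(0), (1000,))

# Attempting to run SVI raises an error: the subsampled plate rescales the
# local log factors for 'b', while the enumerated global site 'a' is not rescaled.
svi = SVI(model, guide, numpyro.optim.Adam(0.01), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(1), 1000, data)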
This has been asked on the forum as well: https://forum.pyro.ai/t/enumeration-and-subsampling-expected-all-enumerated-sample-sites-to-share-common-poutine-scale/4938
The proposed solution here is to scale the log factors as follows ($N$ is the total size, $M$ the subsample size): $\log \sum_a p(a) \prod_{i=1}^{N} p(b_i | a) \approx \frac{N}{M} \log \sum_a p(a) \prod_{i=1}^{M} p(b_i | a)$
Expectation of the left-hand side: $\mathbb{E} [ \log \sum_a p(a) \prod_{i=1}^{N} p(b_i | a) ] = \mathbb{E} [ \log \prod_{i=1}^{N} \sum_a p(a) p(b_i | a) ] = \mathbb{E} [ \log \prod_{i=1}^{N} p(b_i) ] = \mathbb{E} [ \sum_{i=1}^{N} \log p(b_i) ] = \sum_{i=1}^{N} \mathbb{E} [ \log p(b_i) ] = N \, \mathbb{E} [ \log p(b_i) ]$
Expectation of the right-hand side: $\mathbb{E} [ \frac{N}{M} \log \sum_a p(a) \prod_{i=1}^{M} p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log \prod_{i=1}^{M} \sum_a p(a) p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log \prod_{i=1}^{M} p(b_i) ] = \frac{N}{M} \mathbb{E} [ \sum_{i=1}^{M} \log p(b_i) ] = \frac{N}{M} \sum_{i=1}^{M} \mathbb{E} [ \log p(b_i) ] = N \, \mathbb{E} [ \log p(b_i) ]$
Hi @ordabayevy, I don't understand how you can move the product and the sum around. In particular, I'm not sure that your first equality holds: $\mathbb{E} [ \log \sum_a p(a) \prod_{i=1}^{N} p(b_i | a) ] = \mathbb{E} [ \log \prod_{i=1}^{N} \sum_a p(a) p(b_i | a) ]$ - could you clarify?
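A quick numeric check with made-up probabilities (the numbers are arbitrary, not from the thread) illustrates that the sum over $a$ and the product over $i$ cannot be interchanged in general:

import torch

# Arbitrary example: 2 values of a, 3 data points b_i (all numbers made up).
p_a = torch.tensor([0.3, 0.7])              # p(a)
p_b = torch.tensor([[0.2, 0.5],             # p(b_i | a), shape (3, 2)
                    [0.9, 0.1],
                    [0.4, 0.6]])

lhs = torch.log((p_a * p_b.prod(0)).sum())  # log sum_a p(a) prod_i p(b_i | a)
rhs = torch.log((p_a * p_b).sum(1).prod())  # log prod_i sum_a p(a) p(b_i | a)
print(lhs.item(), rhs.item())               # the two values differ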
I think you are right @fehiepsi . Let me think more about this.
So the actual equation should be (this is what the code does): $\log \sum_a p(a) \prod_{i=1}^{N} p(b_i | a) \approx \log \sum_a p(a) \left( \prod_{i \in I_M} p(b_i | a) \right)^{N/M}$, where $I_M$ is the subsampled index set of size $M$.
This seems intuitive to me: subsample within the plate, then scale the product before summing over the enumerated variable. I did some tests and it seems to be unbiased; however, I can't figure out how to prove unbiasedness mathematically.
Code I used to check unbiasedness:
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

a = torch.tensor([0., 1.])                      # locations of the two components
logits_a = torch.log(torch.tensor([0.3, 0.7]))  # log p(a)
# data: drawn uniformly from [0, 1), with the second half shifted to [1, 2)
# (the data distribution does not matter for checking unbiasedness)
b = torch.rand(1000)
b[500:] += 1
d = dist.Normal(a, 1.)
log_b = d.log_prob(b.reshape(-1, 1))            # shape (1000, 2): log p(b_i | a)
# exact value of log sum_a p(a) prod_i p(b_i | a)
expected = torch.logsumexp(log_b.sum(0) + logits_a, 0)
results = []
for _ in range(50000):
    idx = torch.randperm(1000)[:100]            # subsample 100 data points
    scale = 10                                  # N / M = 1000 / 100
    results.append(torch.logsumexp(d.log_prob(b[idx].reshape(-1, 1)).sum(0) * scale + logits_a, 0))
print(expected)
print(torch.mean(torch.tensor(results)))
plt.plot(results)
plt.hlines(expected, 0, 50000, "C1")
plt.show()
>>> tensor(-1087.9426)
>>> tensor(-1087.8069)
I think it's easier to see the issue with a smaller number of data points (e.g. just 2). Assume we are using subsampling to estimate $\log \sum_a p(a) p(x|a) p(y|a) = \log \sqrt{\sum_{a,b} p(a) p(b) p(x|a) p(x|b) p(y|a) p(y|b)}$.
With $N = 2$ and $M = 1$, averaging the scaled estimator over the two possible subsamples gives $\frac{1}{2} \left( \log \sum_a p(a) p(x|a)^2 + \log \sum_a p(a) p(y|a)^2 \right) = \log \sqrt{\sum_{a,b} p(a) p(b) p(x|a)^2 p(y|b)^2}$.
It seems clear to me that the two terms $\sum_{a,b} p(a) p(b) p(x|a) p(x|b) p(y|a) p(y|b)$ and $\sum_{a,b} p(a) p(b) p(x|a)^2 p(y|b)^2$ are not equal.
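To make the gap concrete, here is a small numeric check with made-up probabilities (the values are arbitrary, not from the thread) comparing the two sums above:

import torch

# Arbitrary two-state example with two observations x and y (numbers made up).
p = torch.tensor([0.3, 0.7])     # p(a) = p(b)
px = torch.tensor([0.2, 0.8])    # p(x | a)
py = torch.tensor([0.6, 0.4])    # p(y | a)

# Exact term: sum_{a,b} p(a) p(b) p(x|a) p(x|b) p(y|a) p(y|b)
exact = ((p * px * py)[:, None] * (p * px * py)[None, :]).sum()
# Term implied by the subsampled estimator: sum_{a,b} p(a) p(b) p(x|a)^2 p(y|b)^2
biased = ((p * px ** 2)[:, None] * (p * py ** 2)[None, :]).sum()
print(exact.item(), biased.item())  # the two terms differ, so the estimator is biased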