SCALE
Question about the elbo_SCALE
Hi,
SCALE is a very interesting and useful tool.
But I have a question about the calculation of ELBO.
Why is gamma = q(c|x) = p(c|z) used to calculate the ELBO instead of q(z,c|x)? Does q(c|x) equal q(z,c|x) in
Eq(z,c|x)[log p(x|z) + log p(z|c) + log p(c) - log q(z|x) - log q(c|x)]
in loss.py?
def elbo_SCALE(recon_x, x, gamma, c_params, z_params, binary=True):
    """
    L elbo(x) = Eq(z,c|x)[ log p(x|z) ] - KL(q(z,c|x)||p(z,c))
              = Eq(z,c|x)[ log p(x|z) + log p(z|c) + log p(c) - log q(z|x) - log q(c|x) ]
    """
    mu_c, var_c, pi = c_params; #print(mu_c.size(), var_c.size(), pi.size())
    var_c += 1e-8
    n_centroids = pi.size(1)
    mu, logvar = z_params
    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    logvar_expand = logvar.unsqueeze(2).expand(logvar.size(0), logvar.size(1), n_centroids)

    # log p(x|z)
    if binary:
        likelihood = -binary_cross_entropy(recon_x, x) #;print(logvar_expand.size()) #, torch.exp(logvar_expand)/var_c)
    else:
        likelihood = -F.mse_loss(recon_x, x)

    # log p(z|c)
    logpzc = -0.5*torch.sum(gamma*torch.sum(math.log(2*math.pi) + \
                                           torch.log(var_c) + \
                                           torch.exp(logvar_expand)/var_c + \
                                           (mu_expand-mu_c)**2/var_c, dim=1), dim=1)

    # log p(c)
    logpc = torch.sum(gamma*torch.log(pi), 1)

    # log q(z|x) or q entropy
    qentropy = -0.5*torch.sum(1+logvar+math.log(2*math.pi), 1)

    # log q(c|x)
    logqcx = torch.sum(gamma*torch.log(gamma), 1)

    kld = -logpzc - logpc + qentropy + logqcx

    return torch.sum(likelihood), torch.sum(kld)
Thanks!
Hi, thanks for your interest in SCALE.
- gamma = q(c|x) = p(c|z): gamma is an inference function that infers the cluster (c) from the original data (x). However, in the model c is inferred from the latent variable (z), because only z is connected to c (x->z->c). We therefore replace q(c|x) with p(c|z), which can also be regarded as an approximation.
- q(z, c|x) = q(c|x) * q(z|x): in their actual relationship, c and z can both be inferred directly from x (z<-x->c, unlike the relationship modeled in SCALE, x->z->c), so c and z are conditionally independent given the observed x.
I hope this answers your question.
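To make the first point concrete, here is a minimal sketch of how gamma = p(c|z) could be computed with Bayes' rule. It assumes a diagonal-covariance Gaussian mixture prior, as the per-dimension terms in the log p(z|c) computation suggest. The function name gmm_responsibilities and the list-of-lists layout (one row per centroid) are hypothetical choices for illustration and are not taken from the SCALE code, which works on batched tensors.

```python
import math

def gmm_responsibilities(z, pi, mu_c, var_c):
    """Return gamma[k] = p(c=k | z) for a single latent vector z.

    z     : list of D floats (latent vector)
    pi    : list of K mixture weights, p(c=k)
    mu_c  : K lists of D floats (centroid means)
    var_c : K lists of D floats (centroid variances, diagonal covariance)
    """
    log_joint = []
    for k in range(len(pi)):
        # log N(z; mu_k, diag(var_k)), summed over the latent dimensions
        log_pdf = sum(
            -0.5 * (math.log(2 * math.pi * v) + (zi - m) ** 2 / v)
            for zi, m, v in zip(z, mu_c[k], var_c[k])
        )
        # log p(c=k) + log p(z | c=k)
        log_joint.append(math.log(pi[k]) + log_pdf)

    # Normalise over k, subtracting the max first for numerical stability
    top = max(log_joint)
    weights = [math.exp(v - top) for v in log_joint]
    total = sum(weights)
    return [w / total for w in weights]

# Toy example: z lies close to the first centroid
gamma = gmm_responsibilities(
    z=[0.1, -0.2],
    pi=[0.5, 0.5],
    mu_c=[[0.0, 0.0], [3.0, 3.0]],
    var_c=[[1.0, 1.0], [1.0, 1.0]],
)
```

The resulting gamma sums to 1 over the clusters and puts almost all of its mass on the centroid nearest to z, which is the quantity that stands in for q(c|x) in the loss.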
Thanks for your reply! But I am still a little confused. For example,
# log p(c)
logpc = torch.sum(gamma*torch.log(pi), 1)
I think it should be
Eq(z,c|x)[log p(c)] = \sum_c \int q(z,c|x) log p(c) dz
Why did you use gamma instead of q(z,c|x) in this calculation?
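For reference, a short derivation of this term, assuming the factorization q(z,c|x) = q(z|x) q(c|x) given in the reply above. The subscript notation \gamma_c = q(c|x) and \pi_c = p(c) is introduced here only for illustration:

```latex
\begin{aligned}
\mathbb{E}_{q(z,c|x)}[\log p(c)]
  &= \sum_c \int q(z|x)\, q(c|x)\, \log p(c)\, dz \\
  &= \sum_c q(c|x)\, \log p(c) \int q(z|x)\, dz \\
  &= \sum_c \gamma_c \log \pi_c
\end{aligned}
```

Because log p(c) does not depend on z, the integral over z equals 1 and only the sum over clusters remains. Under that assumption the expectation reduces to the expression in the code, torch.sum(gamma*torch.log(pi), 1).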