Why is tc_loss in bTCVAE negative?
https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/results/btcvae_dsprites/train_losses.log#L5
https://github.com/rtqichen/beta-tcvae/ calculates, in the case of minibatch weighted sampling:

logqz_prodmarginals = (logsumexp(_logqz, dim=1, keepdim=False) - math.log(batch_size * dataset_size)).sum(1)
logqz = (logsumexp(_logqz.sum(2), dim=1, keepdim=False) - math.log(batch_size * dataset_size))

and, in the case of minibatch stratified sampling:

logiw_matrix = Variable(self._log_importance_weight_matrix(batch_size, dataset_size).type_as(_logqz.data))
logqz = logsumexp(logiw_matrix + _logqz.sum(2), dim=1, keepdim=False)
logqz_prodmarginals = logsumexp(logiw_matrix.view(batch_size, batch_size, 1) + _logqz, dim=1, keepdim=False).sum(1)
So, in this codebase, shouldn't we also do the following (in the case of NOT is_mss):
log_qz = (torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)-math.log(batch_size*n_data))
log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False)-math.log(batch_size*n_data)).sum(1)
and, in the case of is_mss:
log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size,batch_size,1)+mat_log_qz, dim=1, keepdim=False).sum(1)
Thanks @UserName-AnkitSisodia! I think you might be right (I am taking a sum instead of marginalizing in log space), but it's been a long time, so I'll have to double-check this weekend.
Did you test it with these changes?
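To make the sum-vs-logsumexp point concrete, here is a toy contrast (shapes and values are made up; this is not the repo's code):

import math
import torch

# mat_log_qz[i, j, k] = log q(z_k of sample i | x_j), shape (batch, batch, latent_dim)
batch_size, latent_dim, n_data = 4, 3, 1000
mat_log_qz = torch.randn(batch_size, batch_size, latent_dim) - 5.0

# Marginalizing in log space: log q(z_i) is estimated with a logsumexp over the
# batch dimension (an average of densities, computed stably in log space).
log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1) - math.log(batch_size * n_data)

# Taking a plain sum instead computes the log of the *product* of the densities
# over the batch, a far more negative number, which skews the estimated terms.
log_qz_sum_bug = mat_log_qz.sum(2).sum(1) - math.log(batch_size * n_data)

print(log_qz, log_qz_sum_bug)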
Using some random matrices (code attached: temp.txt), I ran both your code and Ricky Chen's code to compare what is happening.
I found:
MWS: log_qz != logqz_ricky, log_prod_qzi != logqz_prodmarginals_ricky
MSS: log_prod_qzi_mss == logqz_prodmarginals_ricky_mss, log_qz_mss != logqz_ricky_mss
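For anyone who wants to reproduce this kind of check without the attachment, here is a minimal sketch of such a side-by-side experiment (this is not the attached temp.txt; matrix_log_density_gaussian_local is a local stand-in for the repo's helper, and all numbers are random):

import math
import torch

def matrix_log_density_gaussian_local(x, mu, logvar):
    # Stand-in for the repo's matrix_log_density_gaussian:
    # entry [i, j, k] is log N(x[i, k]; mu[j, k], exp(logvar[j, k])).
    x = x.unsqueeze(1)            # (batch, 1, dim)
    mu = mu.unsqueeze(0)          # (1, batch, dim)
    logvar = logvar.unsqueeze(0)  # (1, batch, dim)
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / logvar.exp())

batch_size, latent_dim, n_data = 8, 5, 737280  # n_data as in dSprites, just for scale
mu = torch.randn(batch_size, latent_dim)
logvar = torch.randn(batch_size, latent_dim)
z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterisation

mat_log_qz = matrix_log_density_gaussian_local(z, mu, logvar)

# MWS estimators, written as in rtqichen/beta-tcvae
log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1) - math.log(batch_size * n_data)
log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1) - math.log(batch_size * n_data)).sum(1)

print("TC estimate on random inputs:", (log_qz - log_prod_qzi).mean().item())

The same mat_log_qz can then be fed through both codebases (with and without is_mss) to compare log_qz and log_prod_qzi term by term.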
So, when I use your code with is_mss=True, I get a -ve tc_loss, and with is_mss=False, I get a -ve mi_loss and a -ve tc_loss. I ran it on the dSprites dataset with batch size 128.
Then I changed the _get_log_pz_qz_prodzi_qzCx function in your code to make it similar to Ricky Chen's code.
batch_size, hidden_dim = latent_sample.shape

# calculate log q(z|x)
log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

# calculate log p(z)
# mean and log var is 0
zeros = torch.zeros_like(latent_sample)
log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False) - math.log(batch_size * n_data)  ## Ankit - modified
log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) - math.log(batch_size * n_data)).sum(1)  ## Ankit - modified
# the two estimates above correspond to is_mss=False (minibatch weighted sampling)

if is_mss:  ## Ankit - modified
    # use minibatch stratified sampling  ## Ankit - modified
    log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)  ## Ankit - modified
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)  ## Ankit - modified
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) + mat_log_qz, dim=1, keepdim=False).sum(1)  ## Ankit - modified

return log_pz, log_qz, log_prod_qzi, log_q_zCx
Then I get +ve losses for everything when is_mss=True, but I get a -ve dw_kl_loss term when is_mss=False.
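(For context on where these signs show up: in the β-TC-VAE decomposition the three terms are formed from the four returned quantities roughly as below; the exact variable names in the repo's loss class may differ.)

mi_loss = (log_q_zCx - log_qz).mean()        # index-code mutual information I[z; x]
tc_loss = (log_qz - log_prod_qzi).mean()     # total correlation TC(z)
dw_kl_loss = (log_prod_qzi - log_pz).mean()  # dimension-wise KL to the prior

So a biased estimate of log q(z) or log prod q(z_j) shows up as one of these terms going negative.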
Awesome, thanks for checking. A few comments:
1/ What do you mean by "+ve" and "-ve"? What is "ve"?
2/ Looking back at it, it seems that I actually had the correct code and then introduced the problem in a late-night push (#43).
Here's what I had before my changes:
def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    if not is_mss:
        log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist,
                                                            latent_sample,
                                                            n_data)
    else:
        log_qz, log_prod_qzi = _minibatch_stratified_sampling(latent_dist,
                                                              latent_sample,
                                                              n_data)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx
def _minibatch_weighted_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    weighted sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample : torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References
    ----------
    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
        autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) -
                    math.log(batch_size * data_size)).sum(dim=1)
    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False
                             ) - math.log(batch_size * data_size)

    return log_qz, log_prod_qzi
def _minibatch_stratified_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with minibatch
    stratified sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.tensor
        sufficient statistics of the latent dimension. E.g. for gaussian
        (mean, log_var) each of shape : (batch_size, latent_dim).

    latent_sample : torch.Tensor
        sample from the latent dimension using the reparameterisation trick
        shape : (batch_size, latent_dim).

    data_size : int
        Number of data in the training set

    References
    ----------
    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
        autoencoders." Advances in Neural Information Processing Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_iw_mat = log_importance_weight_matrix(batch_size, data_size).to(latent_sample.device)
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) +
                                   mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_qz, log_prod_qzi
which is (I believe) exactly what you tested.
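For reference, both helpers implement the estimators from Chen et al. [1]; the minibatch-weighted-sampling version corresponds to

$$\mathbb{E}_{q(z)}\big[\log q(z)\big] \;\approx\; \frac{1}{M}\sum_{i=1}^{M}\left[\log \sum_{j=1}^{M} q\big(z(x_i)\mid x_j\big) \;-\; \log(MN)\right],$$

with M the batch size and N the dataset size (and the same estimator applied per latent dimension for the product of marginals), while the stratified version replaces the uniform 1/(MN) weights with the importance weights from log_importance_weight_matrix.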
- Does it also work for is_mss=False?
- Just to be sure I understand, are you saying that with MSS this makes dw_kl_loss become negative?
- Did you see any impact on the qualitative samples when training a model that way?
Yes, this makes the code exactly the same. Once these changes are made, I get a negative dw_kl_loss term in the case of _minibatch_weighted_sampling. For _minibatch_stratified_sampling, all loss terms come out positive. I tested on dSprites.
And qualitatively, do you see any differences?
I haven't tested that yet. I was just trying to see from the math/code where the error is coming from.
Has this issue been solved? Training on dSprites, I also get a negative TC loss.
I also got a negative TC loss with the dSprites data.