Invertible-Image-Rescaling

Problem about loss_ce

Open MC-E opened this issue 3 years ago • 9 comments

Thanks for your novel work! But I'm a little confused about loss_ce: l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0]. I want to know why this loss function can constrain z to follow a Gaussian distribution. Looking forward to your reply!

MC-E avatar Sep 10 '20 07:09 MC-E

Hi, the strict distribution matching is realized by the JS divergence on X. As stated in the paper, because of the invertibility, distributions on X match if and only if (y, z) follows the joint distribution of (f^y(q(x)), p(z)), which means z follows the Gaussian distribution and z is independent of y. In practice, we introduce a pre-training stage for stable training. In this stage, we use cross-entropy as a weaker surrogate objective that pushes the density of z towards p(z), though the distributions may not strictly match in principle.
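
For concreteness, a minimal sketch of that surrogate term as a standalone function, following the snippet quoted above (the function name is hypothetical; z is assumed to have the batch on its first dimension):

```python
import torch

def forward_ce_loss(z: torch.Tensor, lambda_ce_forw: float) -> torch.Tensor:
    # Surrogate objective from the snippet above: penalizes the squared
    # norm of the latent z, averaged over the batch dimension, which
    # pushes the density of z towards a standard normal p(z).
    return lambda_ce_forw * torch.sum(z ** 2) / z.shape[0]
```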

pkuxmq avatar Sep 10 '20 09:09 pkuxmq

Thanks for your reply! Does this mean that l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0] is only an auxiliary component rather than a strict constraint enforcing the Gaussian distribution?

MC-E avatar Sep 10 '20 09:09 MC-E

Yes, and we have ablation experiments on this in the paper.

pkuxmq avatar Sep 10 '20 09:09 pkuxmq

@pkuxmq So could we use a Gaussian distribution's log_prob to calculate l_forw_ce? I think the l_forw_ce loss works by constraining z toward a Gaussian distribution.
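
A minimal sketch of that alternative, assuming a standard normal prior N(0, 1) (torch.distributions is standard PyTorch; the function name is hypothetical):

```python
import torch
from torch.distributions import Normal

def forw_nll_loss(z: torch.Tensor, lambda_ce_forw: float) -> torch.Tensor:
    # Negative log-likelihood of z under a standard normal prior,
    # averaged over the batch. Since -log N(z; 0, 1) = 0.5 * z**2 + const,
    # this differs from lambda_ce_forw * sum(z**2) / batch only by a
    # factor of 0.5 and an additive constant, so it drives z the same way.
    prior = Normal(torch.zeros_like(z), torch.ones_like(z))
    return -lambda_ce_forw * prior.log_prob(z).sum() / z.shape[0]
```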

codyshen0000 avatar Oct 19 '20 12:10 codyshen0000

I think it might also be related to this answer: "Why is regularization interpreted as a Gaussian prior on my weights?" Minimizing the L2 loss of z has the probabilistic interpretation of assuming z is drawn from a standard normal distribution (mean=0, std=1); thus minimizing L2_norm(z) is equivalent to maximizing the likelihood of z under N(z; mean=0, std=1). Is that correct?
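
A quick numerical check of that interpretation (a standalone sketch, not from this repository):

```python
import math

import torch
from torch.distributions import Normal

z = torch.randn(4, 8)  # a toy latent batch

l2 = torch.sum(z ** 2)
nll = -Normal(0.0, 1.0).log_prob(z).sum()

# Per element, -log N(z; 0, 1) = 0.5 * z**2 + 0.5 * log(2*pi), so the
# summed NLL equals 0.5 * sum(z**2) plus a z-independent constant:
# minimizing the L2 term is the same as maximizing the Gaussian likelihood.
const = 0.5 * math.log(2 * math.pi) * z.numel()
assert torch.allclose(nll, 0.5 * l2 + const)
```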

howardyclo avatar Jan 16 '21 06:01 howardyclo