
Is the KLD calculation correct?

Open victor-shepardson opened this issue 8 years ago • 5 comments

https://github.com/DmitryUlyanov/AGE/blob/0915760d0506cbb3a371aece103745e1ed4806ee/src/losses.py#L38-L41

The KLD appears to use the variance in place of the standard deviation: utils.var() computes the variance as the mean squared distance from the mean, and that value is then squared again in the KLN01Loss module. Should it be (in the default 'qp' direction):

```python
# KL( N(mean, var) || N(0, 1) ) = (var + mean^2 - 1 - log var) / 2, per dimension
t1 = samples_var + samples_mean.pow(2)
t2 = -samples_var.log()

KL = (t1 + t2 - 1).mean() / 2
```

?

(Additionally, the paper gives the KLD as a sum but here it's a mean, changing the meaning of the hyperparameters weighting the reconstruction losses)
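
For reference, the closed form in question is the KL divergence of a diagonal Gaussian from the standard normal, KL(N(μ, σ²) ‖ N(0, 1)) = (σ² + μ² − 1 − log σ²) / 2 per dimension. Here is a minimal sketch checking the proposed expression against that closed form (the tensor values are made up, and the cross-check assumes a PyTorch version that ships torch.distributions):

```python
import torch
from torch.distributions import Normal, kl_divergence

# Hypothetical per-dimension statistics of the encoder outputs.
samples_mean = torch.tensor([0.3, -0.1, 0.5])
samples_var = torch.tensor([0.8, 1.2, 0.9])

# Proposed expression: per-dimension KL( N(mean, var) || N(0, 1) ), averaged.
t1 = samples_var + samples_mean.pow(2)
t2 = -samples_var.log()
KL = (t1 + t2 - 1).mean() / 2

# Cross-check: Normal() takes the standard deviation, hence the sqrt.
ref = kl_divergence(Normal(samples_mean, samples_var.sqrt()),
                    Normal(torch.zeros(3), torch.ones(3))).mean()
assert torch.allclose(KL, ref)
```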

victor-shepardson avatar Sep 30 '17 00:09 victor-shepardson

Hi, yes, it looks like a mistake. Thanks for spotting it.

I will also change sum to mean in the paper, thanks again!

Best, Dmitry

DmitryUlyanov avatar Sep 30 '17 17:09 DmitryUlyanov

I will fix it in several days, when I have time to make sure everything still works.

DmitryUlyanov avatar Sep 30 '17 17:09 DmitryUlyanov

No problem! I think the reconstruction losses in the paper are similarly given as norms, but they are means in the code. And the latent-space loss is said to be L2, but it appears to actually be cosine distance.
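
(For what it's worth, on the unit sphere the two are affinely related, ‖x − y‖² = 2 − 2 cos(x, y), so minimizing one minimizes the other. A quick sketch of that identity, with arbitrary shapes:)

```python
import torch
import torch.nn.functional as F

# For unit-norm vectors, squared L2 distance and cosine similarity
# differ only by an affine transform: ||x - y||^2 = 2 - 2 * cos(x, y).
x = F.normalize(torch.randn(4, 16), dim=1)
y = F.normalize(torch.randn(4, 16), dim=1)

l2_sq = (x - y).pow(2).sum(dim=1)
cos = F.cosine_similarity(x, y, dim=1)
assert torch.allclose(l2_sq, 2 - 2 * cos, atol=1e-6)
```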

victor-shepardson avatar Sep 30 '17 23:09 victor-shepardson

The encoder normalizes all output vectors to norm 1 (i.e., it maps them onto the unit sphere). But if that is the case, a batch of such vectors can never match the unit Gaussian demanded by the KL loss, even when perfectly distributed around the sphere.

Have I missed something, or shouldn't the loss function be calculated with the standard deviation of the transformed vectors, rather than 1?

jshanna100 avatar Sep 19 '18 08:09 jshanna100

The KL divergence will in fact not be zero in the perfect case, but when the KL is minimal, Q is approximately uniform on the sphere, which is what we want.
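
To make the floor concrete: the coordinates of points uniform on the unit sphere in R^d have mean 0 and variance 1/d, so the per-dimension KL to N(0, 1) bottoms out at (1/d − 1 + log d) / 2 > 0 rather than at zero. A quick empirical sketch (the dimensions are chosen arbitrarily):

```python
import math
import torch

# Normalized Gaussian samples are uniform on the unit sphere in R^d.
d, n = 128, 100_000
z = torch.randn(n, d)
z = z / z.norm(dim=1, keepdim=True)

# Per-dimension KL( N(mean, var) || N(0, 1) ), averaged over dimensions.
var, mean = z.var(dim=0), z.mean(dim=0)
kl = (var + mean.pow(2) - 1 - var.log()).mean() / 2

floor = (1 / d - 1 + math.log(d)) / 2
print(kl.item(), floor)  # approximately equal
```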

DmitryUlyanov avatar Sep 19 '18 15:09 DmitryUlyanov