vq-vae-2-pytorch

About the loss of VQVAE2 network

Open ZhanYangen opened this issue 4 years ago • 4 comments
Hi, in the paper the total loss consists of three parts, as follows: [image] However, in the code the loss appears to be different: loss = recon_loss + latent_loss_weight * latent_loss. Honestly, I only know that 'latent_loss' comes from the return value 'diff' of 'Quantize.forward' in vqvae.py, which is diff = (quantize.detach() - input).pow(2).mean(). Although train_vqvae.py worked on my dataset and achieved quite awesome results (thanks for sharing, by the way), I cannot follow the process from the encoder output E(x) to the quantized output e_k, nor how the latent loss is computed. So I am wondering: is there any explanation? Thanks.

ZhanYangen avatar Jan 18 '21 15:01 ZhanYangen

Oh, one more question: is the parameter 'decay' of 'Quantize.__init__' in vqvae.py used to implement the following equation from the paper? [image] I understand it may be knotty to explain the specifics of the 'Quantize' class, so I'd also appreciate it if you could confirm that this code correctly implements the method in the paper.

ZhanYangen avatar Jan 18 '21 15:01 ZhanYangen

latent_loss_weight * (quantize.detach() - input).pow(2).mean() corresponds to \beta ||sg[e] - E(x)||^2_2, and ||sg[E(x)] - e||^2_2 is replaced by EMA updates, as described in the paper. The EMA updates correspond to the equations you cited (decay corresponds to \gamma).

The code in the Quantize class is a vectorized implementation ported from the official repository. It is somewhat tricky, but you can match it with the formulas in the paper.

rosinality avatar Jan 18 '21 16:01 rosinality
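
For readers trying to match the answer above to the formulas, here is a minimal, framework-free sketch of one quantization step in plain Python. It is not the repository's code: the names quantize_step, sq_dist, cluster_size and embed_sum are hypothetical, and it assumes the EMA form N_i <- gamma * N_i + (1 - gamma) * n_i, m_i <- gamma * m_i + (1 - gamma) * sum of the E(x) assigned to code i, e_i <- m_i / N_i, with decay playing the role of gamma. Stop-gradient (detach) only affects backpropagation, so the forward values below are the same with or without it; any smoothing of the cluster sizes is omitted.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize_step(inputs, codebook, cluster_size, embed_sum, decay=0.99):
    """One forward pass plus an EMA codebook update (hypothetical helper).

    inputs:       list of encoder output vectors E(x)
    codebook:     list of code vectors e_i (updated in place)
    cluster_size: running counts N_i (updated in place)
    embed_sum:    running sums m_i (updated in place)
    Returns (quantized, indices, latent_loss).
    """
    dim = len(codebook[0])
    n_codes = len(codebook)

    # E(x) -> e_k: pick the nearest code for every encoder vector.
    indices = [min(range(n_codes), key=lambda k: sq_dist(z, codebook[k]))
               for z in inputs]
    # Copy the selected codes so the later codebook update does not alter them.
    quantized = [list(codebook[k]) for k in indices]

    # Latent (commitment) term: mean over all elements of (sg[e] - E(x))^2.
    # In an autograd framework, e would be detached so only the encoder gets gradients.
    latent_loss = sum(sq_dist(q, z) for q, z in zip(quantized, inputs)) / (len(inputs) * dim)

    # EMA update, which stands in for the ||sg[E(x)] - e||^2 codebook term.
    for i in range(n_codes):
        assigned = [z for z, k in zip(inputs, indices) if k == i]
        cluster_size[i] = decay * cluster_size[i] + (1 - decay) * len(assigned)
        for d in range(dim):
            batch_sum = sum(z[d] for z in assigned)
            embed_sum[i][d] = decay * embed_sum[i][d] + (1 - decay) * batch_sum
        if cluster_size[i] > 0:
            codebook[i] = [m / cluster_size[i] for m in embed_sum[i]]

    return quantized, indices, latent_loss


# Toy usage with made-up numbers; decay=0.5 only keeps the arithmetic easy to follow.
codebook = [[0.0, 0.0], [1.0, 1.0]]
cluster_size = [1.0, 1.0]
embed_sum = [list(e) for e in codebook]
encoder_out = [[0.1, -0.1], [0.9, 1.2], [1.1, 0.8]]

quantized, indices, latent_loss = quantize_step(
    encoder_out, codebook, cluster_size, embed_sum, decay=0.5
)

recon_loss = 0.3  # placeholder value, not a real reconstruction loss
loss = recon_loss + 0.25 * latent_loss  # recon_loss + latent_loss_weight * latent_loss
```

The loop over codes is written for readability; a vectorized version would compute the same counts and sums with one-hot matrix products.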

@rosinality Thanks for your reply, it helps a lot!

ZhanYangen avatar Jan 19 '21 02:01 ZhanYangen

@rosinality do you think it's a good idea to expose the latent loss's weight as a parameter - or d'you think 0.25 should work on most other datasets too? 🤔

neel04 avatar Feb 05 '22 14:02 neel04
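
As a rough illustration of what exposing the weight could look like, here is a small sketch using argparse. The flag name --latent_loss_weight and the helpers build_parser and total_loss are assumptions made for this example, not options the training script is known to provide; the default of 0.25 is the value mentioned in the question above.

```python
import argparse


def build_parser():
    """Build a parser with a hypothetical flag for the latent loss weight (beta)."""
    parser = argparse.ArgumentParser(description="VQ-VAE training options (sketch)")
    parser.add_argument(
        "--latent_loss_weight",
        type=float,
        default=0.25,  # value currently discussed in the thread
        help="weight (beta) applied to the latent loss term",
    )
    return parser


def total_loss(recon_loss, latent_loss, latent_loss_weight):
    """Combine the two terms the same way as the loss line quoted in the thread."""
    return recon_loss + latent_loss_weight * latent_loss


# An empty argument list keeps the default; a real script would call parse_args().
args = build_parser().parse_args([])
example = total_loss(1.0, 0.4, args.latent_loss_weight)
```

Whether 0.25 transfers to other datasets is an empirical question; a flag like this only makes it cheaper to try other values.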