vq-vae-2-pytorch
About the loss of the VQ-VAE-2 network
Hi,
in the paper, the total loss consists of 3 parts, as follows:
L(x, D(e)) = ||x - D(e)||^2_2 + ||sg[E(x)] - e||^2_2 + \beta ||sg[e] - E(x)||^2_2
However, in the code the loss seems to be different (as follows):
loss = recon_loss + latent_loss_weight * latent_loss
Honestly, I only know that 'latent_loss' stems from the return value 'diff' of 'Quantize.forward' in vqvae.py, which is
diff = (quantize.detach() - input).pow(2).mean()
Although train_vqvae.py did work on my dataset and achieved quite awesome results (thanks for sharing, by the way), I cannot follow the process from the encoder output E(x) to the quantized output e_k, nor the computation of the latent loss.
So I am wondering: is there an explanation for this?
Thanks.
Oh, one more question: is the parameter 'decay' of 'Quantize.__init__' in vqvae.py used to implement the following EMA codebook update equations from the paper?
N_i^{(t)} = N_i^{(t-1)} * \gamma + n_i^{(t)} * (1 - \gamma)
m_i^{(t)} = m_i^{(t-1)} * \gamma + \sum_j E(x)_{i,j}^{(t)} * (1 - \gamma)
e_i^{(t)} = m_i^{(t)} / N_i^{(t)}
I understand it may get knotty to explain the specifics of the 'Quantize' class, so I would also appreciate it if you could simply assure me that this code does what the paper describes.
latent_loss_weight * (quantize.detach() - input).pow(2).mean() corresponds to \beta ||sg[e] - E(x)||^2_2, and ||sg[E(x)] - e||^2_2 is replaced by EMA updates, as you can find in the paper. The EMA updates correspond to the equations you cited (decay corresponds to \gamma).
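In other words, a minimal sketch of the correspondence (beta here is hypothetical shorthand for latent_loss_weight; the tensor names follow the repo):

# \beta ||sg[e] - E(x)||^2_2 : detach() plays the role of the stop-gradient sg[.]
latent_loss = (quantize.detach() - input).pow(2).mean()
loss = recon_loss + beta * latent_loss  # recon_loss covers the ||x - D(e)||^2_2 part
# there is no code term for ||sg[E(x)] - e||^2_2: the codebook is moved by the EMA update instead of a gradient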
The code in the Quantize class is a vectorized implementation ported from the official repository. It is somewhat tricky, but you can match it against the formulas in the paper.
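Roughly, the forward pass can be condensed like the sketch below. This is illustrative, not the repo's exact code; it assumes a channel-last input of shape (..., dim) and a codebook embed of shape (dim, n_embed):

import torch
import torch.nn.functional as F

def quantize_forward(input, embed, cluster_size, embed_avg,
                     decay=0.99, eps=1e-5, training=True):
    # input: E(x), shape (..., dim); embed: codebook, shape (dim, n_embed)
    flatten = input.reshape(-1, embed.size(0))
    # squared distance from every E(x) vector to every codebook entry, vectorized
    dist = (flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True))
    embed_ind = dist.argmin(1)                      # index k of the nearest code
    quantize = embed.t()[embed_ind].view_as(input)  # e_k for each E(x) vector

    if training:  # EMA codebook update in place of ||sg[E(x)] - e||^2_2
        with torch.no_grad():
            onehot = F.one_hot(embed_ind, embed.size(1)).type(flatten.dtype)
            cluster_size.mul_(decay).add_(onehot.sum(0), alpha=1 - decay)      # N_i
            embed_avg.mul_(decay).add_(flatten.t() @ onehot, alpha=1 - decay)  # m_i
            n = cluster_size.sum()
            smoothed = (cluster_size + eps) / (n + embed.size(1) * eps) * n    # Laplace smoothing
            embed.copy_(embed_avg / smoothed.unsqueeze(0))                     # e_i = m_i / N_i

    diff = (quantize.detach() - input).pow(2).mean()  # ||sg[e] - E(x)||^2_2
    # straight-through estimator: forward pass uses e_k, gradient flows to E(x)
    quantize = input + (quantize - input).detach()
    return quantize, diff, embed_ind

Here cluster_size tracks N_i, embed_avg tracks m_i, and decay is \gamma from the equations above; the last line is the straight-through estimator that copies gradients from the quantized output back to E(x).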
@rosinality Thanks for your reply, it helps a lot!
@rosinality would you think it's a good idea to expose the latent loss weight as a parameter (something like the sketch below), or do you think 0.25 should work on most other datasets too? 🤔
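For instance, a hypothetical flag mirroring the argparse options train_vqvae.py already defines, assuming the weight is currently hardcoded as latent_loss_weight = 0.25 inside train():

parser.add_argument('--latent-loss-weight', type=float, default=0.25)
# ...then pass args.latent_loss_weight into train() instead of the hardcoded 0.25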