WaveRNN
EMA trains faster
Hi,
Cool project!
When I was trying VQVAE I found that using a moving average like the one described in the appendix trained a lot faster and gave better results! There is a zalandoresearch repo with an open-source example. It is a bit hard to parallelize though, since it does not rely on the optimizer to learn the embedding.
Cheers 👍
Thanks for the information. It sounds like a good modification to try!
Let's leave this open - someone (perhaps I) might want to implement this.
I've tried WaveNet+VQVAE with EMA, and it seems that using EMA only on the VQ space can give a reasonable result.
Here is my result.zip, not as good as the official demo, though.
The implementation just follows https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/3 ; glad if you find it helpful.
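For context, the pattern in that forum thread is a shadow copy of the parameters that gets updated after each optimizer step. A minimal sketch of that style (names and defaults are mine, not the exact code from the post):

```python
import torch

class EMA:
    """Shadow-parameter EMA: the optimizer keeps training the real
    parameters, while a decayed copy is maintained on the side
    (typically swapped in for evaluation)."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, param):
        # keep a detached copy of the parameter's initial value
        self.shadow[name] = param.detach().clone()

    def update(self, name, param):
        # shadow <- decay * shadow + (1 - decay) * param
        self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)
        return self.shadow[name]
```

After every `optimizer.step()` you would call `update` on the embedding weight and use the shadow copy at eval time.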
If I am not mistaken, the linked example is a regularization scheme for the model parameters, not the EMA used to learn just the embeddings? Do you have an example of your implementation?
What do you mean by "the linked example is a regularization scheme for the model parameters"?
I mean it's the final audio result of my implementation of VQ-VAE with EMA only on the embedding.
Well, the forum post you linked is not the EMA that is described in the VQVAE paper. Did you modify the implementation to follow the VQVAE paper?
Actually I think they also use EMA only on the embedding.
You can see the implementation in sonnet/vqvae.py; I haven't really read their implementation of EMA.
The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses exponential moving averages to update the embedding vectors instead of an auxiliary loss.
I mean they are both called exponential moving average (EMA), but the forum post you linked and what they do in the VQVAE paper (or in the sonnet code) are fundamentally different things, even if both are applied only to the embedding parameters. The forum post seems to be a regularization strategy (I am not sure what it aims to do exactly), while the VQVAE EMA is a different learning algorithm for the embeddings.
I am not sure you tried the correct VQVAE EMA learning algorithm, since you wrote that you followed the forum post.
Yes, I followed the forum method of exponential moving average. Sorry, I didn't realize they are different things. To my knowledge EMA is a regularization strategy, so I did not check the EMA idea in VQVAE. Thanks for the reminder, I'll check the paper tomorrow.
Cool 👍 Looking forward to the samples 😄
I quickly scanned the Appendix of the VQVAE paper.
It seems they use the method below:
Remove the codebook term `(sg(z_e(x)) - e) ** 2` from the loss function:

L = loss_recon + (sg(z_e(x)) - e) ** 2 + 0.25 * (z_e(x) - sg(e)) ** 2
And update the embedding `e` as below:

- Suppose the current embedding center is `e_i`, and there are N encoder outputs closest to it.
- Then `e_i` updates itself with an EMA over those N outputs, roughly as in the sketch below.
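A minimal PyTorch sketch of that update (my own reading, untested; tensor names are mine, gamma = 0.99 as in the paper, and the Laplace smoothing is borrowed from the sonnet implementation):

```python
import torch

@torch.no_grad()
def ema_codebook_update(z_e, codebook, cluster_size, embed_sum, gamma=0.99, eps=1e-5):
    """One EMA update of the VQ-VAE codebook.

    z_e:          (B, D) encoder outputs
    codebook:     (K, D) embedding vectors e_i
    cluster_size: (K,)   running EMA of the counts N_i
    embed_sum:    (K, D) running EMA of the summed encoder outputs m_i
    """
    # assign every encoder output to its closest embedding
    dist = torch.cdist(z_e, codebook)                    # (B, K)
    idx = dist.argmin(dim=1)                             # (B,)
    onehot = torch.nn.functional.one_hot(idx, codebook.size(0)).type_as(z_e)

    # n_i = number of outputs assigned to cluster i; dw_i = their sum
    n = onehot.sum(dim=0)                                # (K,)
    dw = onehot.t() @ z_e                                # (K, D)

    # EMA updates: N_i <- gamma * N_i + (1 - gamma) * n_i, same for m_i
    cluster_size.mul_(gamma).add_(n, alpha=1 - gamma)
    embed_sum.mul_(gamma).add_(dw, alpha=1 - gamma)

    # Laplace smoothing keeps empty clusters from dividing by zero
    K = codebook.size(0)
    smoothed = (cluster_size + eps) / (cluster_size.sum() + K * eps) * cluster_size.sum()

    # e_i <- m_i / N_i
    codebook.copy_(embed_sum / smoothed.unsqueeze(1))
    return idx
```

Here `cluster_size` and `embed_sum` would be persistent buffers, initialized e.g. to zeros and a copy of the initial codebook.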
What's your take on that?
Yes, sounds correct. You can look at this notebook for a working example.