
EMA trains faster

pfriesch opened this issue 5 years ago • 12 comments

Hi,

Cool project!

When I was trying VQVAE I found that using a moving average as described in the appendix trained a lot faster and gave better results! There is a zalandoresearch repo that has an open source example. It is a bit hard to parallelize though, since it does not depend on the optimizer to learn the embedding.

Cheers 👍

pfriesch avatar Mar 02 '19 14:03 pfriesch

Thanks for the information. It sounds like a good modification to try!

mkotha avatar Mar 03 '19 06:03 mkotha

Let's leave this open - someone (perhaps I) might want to implement this.

mkotha avatar Mar 04 '19 10:03 mkotha

I've tried WaveNet+VQVAE with EMA, and it seems that using EMA only on the VQ space gives a reasonable result.

Here is my result.zip, not as good as the official demo, though.

The implementation just follows https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/3 — glad if you find it helpful.
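For context, the pattern in that thread is roughly a shadow copy of the model parameters that gets smoothed with an exponential moving average. A minimal sketch, assuming PyTorch; the class and variable names here are just illustrative, not the exact code from the thread:

```python
import torch

class EMA:
    """Keeps a shadow copy of parameters, updated as an exponential moving average."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # shadow holds the smoothed copy of every trainable parameter
        self.shadow = {name: p.detach().clone()
                       for name, p in model.named_parameters() if p.requires_grad}

    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current parameter value
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.requires_grad:
                    self.shadow[name].mul_(self.decay).add_(p, alpha=1 - self.decay)
```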

mazzzystar avatar Mar 18 '19 02:03 mazzzystar

If I am not mistaken, the linked example is a regularization scheme for the model parameters? That is not the EMA used to learn just the embeddings, is it? Do you have an example of your implementation?

pfriesch avatar Mar 18 '19 12:03 pfriesch

What do you mean by "the linked example is a regularization scheme for the model parameters"?

I mean it's the final audio result of my implementation of VQ-VAE with EMA only on the embedding.

mazzzystar avatar Mar 18 '19 12:03 mazzzystar

Well, the forum post you linked is not the EMA that is described in the VQVAE paper. Did you modify the implementation to follow the VQVAE paper?

pfriesch avatar Mar 18 '19 12:03 pfriesch

Actually I think they also use EMA only on the embedding.

You can see the implementation in sonnet/vqvae.py; I haven't really read their implementation of EMA.

The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses exponential moving averages to update the embedding vectors instead of an auxiliary loss.

mazzzystar avatar Mar 18 '19 13:03 mazzzystar

I mean that both are called exponential moving average (EMA), but the forum post you linked and what they do in the VQVAE paper (or in the sonnet code) are fundamentally different things, even if both are applied only to the embedding parameters. The forum post seems to be a regularization strategy (I am not sure what it aims to do exactly), while the VQVAE EMA is a different learning algorithm for the embeddings.

I am not sure if you tried the correct VQVAE EMA learning algorithm, since you wrote that you followed the forum post.

pfriesch avatar Mar 18 '19 14:03 pfriesch

Yes, I followed the forum method of exponential moving average. Sorry, I didn't realize they are different things. To my knowledge EMA was a regularization strategy, so I did not check the EMA idea in VQVAE. Thanks for the reminder, I'll check the paper tomorrow.

mazzzystar avatar Mar 18 '19 14:03 mazzzystar

Cool 👍 Looking forward to the samples 😄

pfriesch avatar Mar 18 '19 14:03 pfriesch

I quickly scanned the Appendix of the VQVAE paper.

It seems they use the method below: remove the codebook term ||sg[z_e(x)] - e||^2 from the loss function

L = loss_recon + ||sg[z_e(x)] - e||^2 + 0.25 * ||z_e(x) - sg[e]||^2

And update the embedding e as below:

  • Suppose the current embedding center is e_i, and there are N encoder outputs closest to it.
  • Then e_i is updated with an EMA over those N encoder outputs, roughly as in the sketch below.
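A minimal sketch of that update in PyTorch, with illustrative names and assumed values for decay and eps (the small epsilon smoothing just avoids dividing by zero for unused codes); this is one possible reading of the appendix, not any particular repo's code:

```python
import torch

def ema_codebook_update(embedding, cluster_size, embed_avg, z_e, codes,
                        decay=0.99, eps=1e-5):
    """One EMA update of the codebook.

    embedding:    (K, D) current codebook e
    cluster_size: (K,)   EMA count N_i of encoder outputs assigned to each code
    embed_avg:    (K, D) EMA sum m_i of encoder outputs assigned to each code
    z_e:          (B, D) encoder outputs for this batch
    codes:        (B,)   index of the nearest code for each encoder output (long)
    """
    K, D = embedding.shape
    one_hot = torch.nn.functional.one_hot(codes, K).type_as(z_e)   # (B, K)

    # N_i <- decay * N_i + (1 - decay) * (number of outputs assigned to code i)
    cluster_size.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    # m_i <- decay * m_i + (1 - decay) * (sum of outputs assigned to code i)
    embed_avg.mul_(decay).add_(one_hot.t() @ z_e, alpha=1 - decay)

    # Smooth the counts so empty codes don't divide by zero, then e_i = m_i / N_i
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + K * eps) * n
    embedding.copy_(embed_avg / smoothed.unsqueeze(1))
```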

What's your idea on that?

mazzzystar avatar Mar 19 '19 02:03 mazzzystar

Yes, sounds correct. You can look at this notebook for a working example.

pfriesch avatar Mar 19 '19 09:03 pfriesch