vector-quantize-pytorch

EMA update on CosineCodebook

Open · roomo7time opened this issue 2 years ago · 7 comments

The original ViT-VQGAN paper does not seem to use an EMA update for codebook learning, since its codebook consists of unit-normalized vectors.

In particular, to my understanding, an EMA update does not quite make sense when the encoder outputs and codebook vectors are unit-normalized.

What's your take on this? Should we NOT use EMA update with CosineCodebook?

roomo7time avatar Sep 27 '22 07:09 roomo7time
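For context on the question, here is a minimal, hypothetical sketch (not the library's actual code) of what an EMA update looks like in the cosine setting. The presumable concern is that the running average of unit vectors has norm below one, so a plain EMA pulls the codes off the unit sphere unless they are re-normalized after every step:

```python
import torch.nn.functional as F

def ema_update_cosine_codebook(codebook, encodings, assignments, decay=0.99):
    """Hypothetical sketch of an EMA update for a unit-normalized (cosine) codebook.
    Not the library's actual code.

    codebook:    (num_codes, dim) unit-normalized code vectors
    encodings:   (batch, dim) unit-normalized encoder outputs
    assignments: (batch,) index of the nearest code for each encoding
    """
    num_codes = codebook.shape[0]

    # sum of encoder outputs assigned to each code, and how many were assigned
    one_hot = F.one_hot(assignments, num_codes).type_as(encodings)  # (batch, num_codes)
    summed = one_hot.t() @ encodings                                # (num_codes, dim)
    counts = one_hot.sum(dim=0).clamp(min=1e-5).unsqueeze(-1)       # (num_codes, 1)
    batch_mean = summed / counts

    # plain EMA: the average of unit vectors has norm < 1, so without the
    # final re-normalization the codes drift away from the unit sphere
    new_codebook = decay * codebook + (1 - decay) * batch_mean
    return F.normalize(new_codebook, dim=-1)
```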

Would you like to explain why EMA does not work for a unit-normalized codebook?

pengzhangzhi avatar Oct 29 '22 08:10 pengzhangzhi

I found that when using EMA with the cosine codebook, the L2 norm of the input to the VQ module grows gradually, from 22 -> 20000, leading to a growing training loss. Has anyone met this problem?

Saltychtao avatar Nov 03 '22 12:11 Saltychtao
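A rough sketch of how one might watch for this blow-up during training, assuming the VectorQuantize interface from this repo's README (dim, codebook_size, use_cosine_sim); the helper name check_vq_input_norm is made up for illustration:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim=256, codebook_size=512, use_cosine_sim=True)

def check_vq_input_norm(x, step, log_every=100):
    # mean L2 norm over the batch, measured right before the vq module sees it
    if step % log_every == 0:
        print(f"step {step}: mean ||x||_2 = {x.norm(dim=-1).mean().item():.1f}")

x = torch.randn(1, 1024, 256)   # (batch, seq, dim) encoder output
check_vq_input_norm(x, step=0)
quantized, indices, commit_loss = vq(x)
```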

> I found that when using EMA with the cosine codebook, the L2 norm of the input to the VQ module grows gradually, from 22 -> 20000, leading to a growing training loss. Has anyone met this problem?

In case anyone else hits this problem: I added a LayerNorm layer after the vq_in projection, and the growing-norm problem is largely solved.

Saltychtao avatar Nov 23 '22 08:11 Saltychtao
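Since project_in lives inside VectorQuantize, one way to get a similar effect without patching the library is to normalize the features just before they enter the module. This only approximates @Saltychtao's fix (the LayerNorm sits before the projection rather than after it), and the wrapper class below is hypothetical:

```python
import torch
from torch import nn
from vector_quantize_pytorch import VectorQuantize

class NormalizedVQ(nn.Module):
    """Hypothetical wrapper: LayerNorm the features right before the VQ module,
    so the norm of what the quantizer sees stays bounded."""

    def __init__(self, dim=256, codebook_size=512):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.vq = VectorQuantize(dim=dim, codebook_size=codebook_size, use_cosine_sim=True)

    def forward(self, x):
        return self.vq(self.norm(x))

quantizer = NormalizedVQ()
x = torch.randn(2, 1024, 256)                  # (batch, seq, dim)
quantized, indices, commit_loss = quantizer(x)
```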

@Saltychtao I also encountered a similar issue. Does vq_in refer to VectorQuantize.project_in?

jzhang38 avatar Mar 09 '23 04:03 jzhang38

> @Saltychtao I also encountered a similar issue. Does vq_in refer to VectorQuantize.project_in?

Yes.

Saltychtao avatar Apr 18 '23 01:04 Saltychtao

> I found that when using EMA with the cosine codebook, the L2 norm of the input to the VQ module grows gradually, from 22 -> 20000, leading to a growing training loss. Has anyone met this problem?
>
> In case anyone else hits this problem: I added a LayerNorm layer after the vq_in projection, and the growing-norm problem is largely solved.

@Saltychtao Hi, just want to make sure: the current version of the implementation here seems to put a normalization (l2norm) after project_in. I also encounter the training-loss explosion issue with the current version.

santisy avatar May 13 '24 19:05 santisy
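For what it's worth, l2norm and LayerNorm are not interchangeable here; a tiny illustration (not library code) of the difference:

```python
import torch
import torch.nn.functional as F
from torch import nn

# l2norm only rescales each vector to unit length; LayerNorm also centers it
# and applies a learned affine, so the two feed different statistics downstream
x = torch.tensor([[10.0, 0.0, 0.0, 0.0]])

l2 = F.normalize(x, dim=-1)   # [[1, 0, 0, 0]] -> norm exactly 1, direction kept
ln = nn.LayerNorm(4)(x)       # mean removed over the feature dim, norm ~ sqrt(4)

print(l2.norm(dim=-1), ln.norm(dim=-1))
```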

@santisy want to try turning this on (following @Saltychtao's solution)?

let me know if it helps

lucidrains avatar May 13 '24 20:05 lucidrains