vector-quantize-pytorch

Missing feature to reproduce SoundStream's Residual Vector Quantizer

Open • wesbz opened this issue 4 years ago • 19 comments

Hi, thanks for this cool work! I noticed that a few features SoundStream uses to improve codebook usage are missing, so this isn't yet an exact implementation of the quantizer described in the SoundStream paper:

  • "First, instead of using a random initialization for the codebook vectors, we run the k-means algorithm on the first training batch and use the learned centroids as initialization"
  • "Second, as proposed in [34], when a codebook vector has not been assigned any input frame for several batches, we replace it with an input frame randomly sampled within the current batch."

I'm currently working on an implementation of this work. I'll use and adapt your code for it, but thought you might want to know about these gaps. I'll keep you posted :)

wesbz, Oct 17 '21 10:10
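A minimal sketch of the first bullet, assuming the encoder outputs have already been flattened to shape `(N, D)`: run a few iterations of plain k-means (Lloyd's algorithm) on the first training batch and use the centroids as the initial codebook. `kmeans_init` is a hypothetical helper written for illustration, not this library's API; it also returns per-cluster counts, which can be handy for seeding EMA statistics.

```python
import torch

def kmeans_init(samples: torch.Tensor, num_clusters: int, num_iters: int = 10):
    """Hypothetical helper: plain k-means (Lloyd's algorithm) on one batch.

    samples: (N, D) flattened encoder outputs from the first training batch.
    Returns centroids (K, D) and the cluster counts (K,) from the last assignment step.
    """
    assert num_iters >= 1
    num_samples = samples.shape[0]
    # Seed the centroids with randomly chosen input vectors.
    if num_samples >= num_clusters:
        idx = torch.randperm(num_samples)[:num_clusters]
    else:
        idx = torch.randint(0, num_samples, (num_clusters,))
    centroids = samples[idx].clone()

    for _ in range(num_iters):
        # Assign each sample to its nearest centroid (Euclidean distance).
        assign = torch.cdist(samples, centroids).argmin(dim=-1)
        counts = torch.bincount(assign, minlength=num_clusters)
        # Per-cluster sums divided by counts give the new means.
        sums = torch.zeros_like(centroids).index_add_(0, assign, samples)
        means = sums / counts.clamp(min=1).unsqueeze(-1).to(samples.dtype)
        # Clusters that received no samples keep their previous centroid.
        centroids = torch.where(counts.unsqueeze(-1) > 0, means, centroids)
    return centroids, counts


torch.manual_seed(0)
first_batch = torch.randn(512, 64)
codebook, cluster_size = kmeans_init(first_batch, num_clusters=128)
print(codebook.shape)  # torch.Size([128, 64])
```

Empty clusters keep their seed vector here; other reasonable choices exist, such as re-seeding them from the batch.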

@wesbz hey Wes! thanks for raising this issue! I can finish the first bullet point tonight, and as for the second, it seems tied to the other codebook revival issue, and i can tackle that all at once next week

lucidrains, Oct 17 '21 19:10

@lucidrains thanks for your response! I have completed the first bullet point, but I'd be more than glad to see your implementation of this!

wesbz, Oct 17 '21 20:10

I completed (I think) both points; you can see it here

wesbz, Oct 18 '21 00:10

@wesbz oh whoops! i just finished number one too (https://github.com/lucidrains/vector-quantize-pytorch/commit/d28d851aa514b244120c2fd286a7bab380ccd127). do you want to review it and see if that works for your use-case?

lucidrains, Oct 18 '21 17:10

@wesbz i'll take a look at your number 2 solution later!

lucidrains, Oct 18 '21 17:10

@wesbz also, would you be interested in testing https://github.com/lucidrains/vector-quantize-pytorch/issues/2#issuecomment-937387129 if i were to build it?

lucidrains, Oct 18 '21 17:10

@wesbz added my version of bullet point 2 here https://github.com/lucidrains/vector-quantize-pytorch#expiring-stale-codes

lucidrains, Oct 19 '21 05:10
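For reference, a hedged sketch of the replacement rule from SoundStream's second bullet, assuming usage is tracked as an EMA of per-code assignment counts (`cluster_size`) and a code counts as stale once that value drops below a threshold (2 is the value discussed later in this thread). This is not the code behind the README link above; `expire_stale_codes`, and the choice of what to reset `cluster_size` to after a replacement, are assumptions.

```python
import torch

def expire_stale_codes(codebook, cluster_size, batch_vectors, threshold=2.0):
    """Replace codes whose EMA usage fell below `threshold` with random batch vectors.

    codebook:      (K, D) code vectors
    cluster_size:  (K,)   EMA of how many vectors were assigned to each code
    batch_vectors: (N, D) encoder outputs from the current batch
    """
    stale = cluster_size < threshold
    num_stale = int(stale.sum())
    if num_stale == 0:
        return codebook, cluster_size
    # Sample replacement vectors uniformly from the current batch.
    idx = torch.randint(0, batch_vectors.shape[0], (num_stale,))
    codebook = codebook.clone()
    codebook[stale] = batch_vectors[idx]
    # Hypothetical choice: reset revived codes' usage to the threshold so they
    # are not expired again on the very next step.
    cluster_size = cluster_size.clone()
    cluster_size[stale] = threshold
    return codebook, cluster_size


codebook = torch.zeros(4, 3)
cluster_size = torch.tensor([5.0, 0.5, 3.0, 1.0])  # codes 1 and 3 are stale
batch = torch.ones(10, 3)
new_codebook, new_size = expire_stale_codes(codebook, cluster_size, batch)
```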

Hi, thanks for letting me know! I'll try to use your version, since it looks very well written and it avoids needless code duplication in my implementation of SoundStream :joy: I'll let you know if I ever run into something weird :wink: Thanks again

wesbz, Oct 19 '21 21:10

Also! Regarding your implementation of factorized codes and $l_2$-normalized codes: I don't think they're used in SoundStream, so I won't have occasion to test them :sweat_smile:

wesbz, Oct 20 '21 09:10
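Since these two options come up here, a rough sketch of how they are commonly described, offered as an assumption rather than a description of this library's code: *factorized* codes project the input to a low-dimensional lookup space before the nearest-neighbour search, and *$l_2$-normalized* codes normalize both inputs and code vectors, so the nearest code by Euclidean distance is simply the one with the highest cosine similarity. `l2_normalized_lookup` and the linear projection are illustrative.

```python
import torch
import torch.nn.functional as F

def l2_normalized_lookup(x, codebook, project=None):
    """Nearest-code lookup on the unit sphere, optionally in a projected space.

    x: (N, D) inputs; codebook: (K, d) code vectors; project: optional D -> d map.
    """
    if project is not None:
        x = project(x)  # factorized codes: search in a smaller space
    x = F.normalize(x, dim=-1)
    codes = F.normalize(codebook, dim=-1)
    # For unit vectors, ||a - b||^2 = 2 - 2 a.b, so max dot product = nearest code.
    indices = (x @ codes.t()).argmax(dim=-1)
    return codes[indices], indices


torch.manual_seed(0)
proj = torch.nn.Linear(256, 8)
quantized, idx = l2_normalized_lookup(torch.randn(32, 256), torch.randn(512, 8), project=proj)
```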

Oh, and by the way, one thing from the SoundStream paper is still missing (I don't need it, so I didn't mention it in my first comment): bitrate scalability of the RVQ via quantizer dropout. I can open a pull request for that later ;)

wesbz, Oct 20 '21 09:10
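A minimal sketch of quantizer dropout in a residual VQ, under stated simplifications: during training only the first `n_q` stages are used, with `n_q` drawn uniformly from 1 to the number of stages, so a single model learns to decode at several bitrates (each extra stage with a K-entry codebook adds log2(K) bits per frame). SoundStream samples `n_q` per example; this sketch samples once per call for brevity, and `residual_quantize` is a hypothetical name.

```python
import random
import torch

def residual_quantize(x, codebooks, quantizer_dropout=False):
    """Residual VQ: each stage quantizes what the previous stages missed.

    x: (N, D); codebooks: list of (K, D) tensors, one per stage.
    With quantizer_dropout=True (training), only the first n_q stages are used,
    with n_q drawn uniformly from 1..len(codebooks).
    """
    num_stages = len(codebooks)
    n_q = random.randint(1, num_stages) if quantizer_dropout else num_stages
    residual = x
    quantized = torch.zeros_like(x)
    all_indices = []
    for codebook in codebooks[:n_q]:
        indices = torch.cdist(residual, codebook).argmin(dim=-1)
        chosen = codebook[indices]
        quantized = quantized + chosen
        residual = residual - chosen  # the next stage sees only the leftover error
        all_indices.append(indices)
    return quantized, torch.stack(all_indices, dim=-1)


torch.manual_seed(0)
random.seed(0)
codebooks = [torch.randn(16, 4) for _ in range(8)]
x = torch.randn(10, 4)
q_full, idx_full = residual_quantize(x, codebooks)                          # all 8 stages
q_drop, idx_drop = residual_quantize(x, codebooks, quantizer_dropout=True)  # 1..8 stages
```

At inference time, choosing how many stages to keep then selects the bitrate.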

@wesbz ok sounds good! thank you for the PR! i realized there's an issue with the cluster size: we should probably keep track of the init state across all codes, and on the first EMA update, just set the cluster size to the observed value instead of assuming 0 for the moving average

I'll check out the dropout later this evening as well!

lucidrains, Oct 20 '21 19:10
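A sketch of the fix lucidrains describes, assuming usage is kept as an exponential moving average of per-code assignment counts: keep an init flag and, on the first update, copy the observed counts instead of decaying from 0. Starting from 0 with decay 0.99 would leave every code at 1% of its true count after the first step, so every code would look rarely used for many steps. `EMAClusterSize` and `initted` are hypothetical names.

```python
import torch

class EMAClusterSize:
    """EMA of per-code assignment counts, seeded on the first update."""

    def __init__(self, codebook_size, decay=0.99):
        self.decay = decay
        self.value = torch.zeros(codebook_size)
        self.initted = False  # hypothetical flag: has the EMA seen any batch yet?

    def update(self, indices):
        counts = torch.bincount(indices, minlength=self.value.numel()).float()
        if not self.initted:
            # First update: take the observed counts directly instead of 0 * decay + ...
            self.value = counts
            self.initted = True
        else:
            self.value = self.decay * self.value + (1 - self.decay) * counts
        return self.value


ema = EMAClusterSize(codebook_size=4, decay=0.99)
first = ema.update(torch.tensor([0, 0, 1, 3]))
print(first)  # tensor([2., 1., 0., 1.])
second = ema.update(torch.tensor([0, 0, 0, 0]))  # code 0: 0.99 * 2 + 0.01 * 4
```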

Shouldn't the threshold also be based on the normalized cluster size? Wouldn't one batch that is smaller than usual throw things off?

lucidrains, Oct 21 '21 14:10

Yes, I was also wondering about the first use of the EMA; I don't know what is commonly done with these statistics. Regarding the threshold, what do you mean by "cluster size normalized"? In SoundStream, they use a decay of 0.99, so there aren't many scenarios where the EMA would fall below 2, unless the code has not been used, or barely used, for quite a long time.

wesbz, Oct 21 '21 15:10
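To make the decay argument concrete: once a code stops being selected, each EMA update multiplies its cluster size by the decay $\lambda$, so after $n$ unused batches it equals $c_0 \lambda^n$ and it drops below a threshold $\tau$ after roughly $\ln(\tau / c_0) / \ln \lambda$ batches. A small helper (hypothetical, for illustration only) that counts those batches:

```python
import math

def steps_until_expiry(initial_size, threshold=2.0, decay=0.99):
    """Consecutive unused batches before initial_size * decay**n drops below threshold."""
    if initial_size < threshold:
        return 0
    # Smallest integer n with initial_size * decay**n < threshold.
    return math.floor(math.log(threshold / initial_size) / math.log(decay)) + 1


print(steps_until_expiry(10.0))  # → 161
```

So with $\lambda = 0.99$, a code that had settled at about 10 assignments per batch has to go unused for roughly 160 batches before it crosses a threshold of 2.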

If there are few vectors in a batch and the codebook is large, many codes may not be chosen at all, e.g. 256 vectors vs. a codebook size of 1024. Maybe setting the threshold to something like (number of codes in a batch) / (batch size) would be better? (I haven't experimented with it yet.)

MagicGeek2, Oct 21 '21 16:10
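The scenario above in numbers: with 256 vectors per batch and 1024 codes, even perfectly uniform usage gives each code 256 / 1024 = 0.25 assignments per batch, so an absolute threshold of 2 would eventually flag every code. One hypothetical way to act on the "normalized cluster size" idea is to compare each code's share of total usage with the uniform share 1/K; `stale_mask_normalized` and `min_share_ratio` are illustrative names, not library options.

```python
import torch

def stale_mask_normalized(cluster_size, min_share_ratio=0.1):
    """Flag codes whose share of total EMA usage is below min_share_ratio / K.

    Because it uses shares rather than raw counts, the test does not depend on
    how many vectors each batch contains.
    """
    codebook_size = cluster_size.numel()
    share = cluster_size / cluster_size.sum().clamp(min=1e-8)
    return share < min_share_ratio / codebook_size


uniform = torch.full((1024,), 256 / 1024)          # 0.25 expected uses per code
print(bool((uniform < 2.0).all()))                 # → True (absolute threshold flags all)
print(bool(stale_mask_normalized(uniform).any()))  # → False (normalized flags none)
```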

by the way, nice work! 👍

MagicGeek2, Oct 21 '21 16:10

the easy way out would be to have an extra hyperparameter that only expires when there are sufficient vectors, maybe min_batch_vectors_check_expiry or something (i'm really bad at naming)

lucidrains, Oct 21 '21 19:10
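A sketch of that gate, assuming expiry uses an EMA cluster-size threshold: skip the stale-code check entirely when the batch holds fewer than `min_batch_vectors_check_expiry` vectors. The hyperparameter name is lucidrains's tentative suggestion here, not a confirmed library argument, and `maybe_expire` is hypothetical.

```python
import torch

def maybe_expire(codebook, cluster_size, batch_vectors,
                 threshold=2.0, min_batch_vectors_check_expiry=1024):
    """Replace stale codes with random batch vectors, but only for large batches.

    Returns the (possibly updated) codebook and whether the check ran.
    """
    if batch_vectors.shape[0] < min_batch_vectors_check_expiry:
        return codebook, False  # too few vectors for usage stats to be meaningful
    stale = cluster_size < threshold
    if stale.any():
        idx = torch.randint(0, batch_vectors.shape[0], (int(stale.sum()),))
        codebook = codebook.clone()
        codebook[stale] = batch_vectors[idx]
    return codebook, True


codebook = torch.zeros(8, 4)
cluster_size = torch.tensor([3.0, 0.1, 5.0, 2.5, 0.0, 4.0, 6.0, 2.0])  # codes 1, 4 stale
small, checked_small = maybe_expire(codebook, cluster_size, torch.ones(256, 4))
large, checked_large = maybe_expire(codebook, cluster_size, torch.ones(2048, 4))
```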

In SoundStream, the encoder is built so that 24 kHz waveforms are transformed into 75 Hz embeddings. So a batch of 16 one-second audio samples already gives 16 × 75 = 1200 vectors for your RVQ to process, more than 1024. And again, as long as the decay factor is close to 1, you shouldn't see instability in the codebooks. I'll see what it looks like once I have it working correctly :)

wesbz, Oct 21 '21 22:10
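The frame-count arithmetic behind this, as a quick check: 24 kHz audio at a 75 Hz embedding rate means one vector per 24000 / 75 = 320 samples, so a one-second clip gives 75 vectors per quantizer stage. `vectors_per_batch` is just an illustrative helper.

```python
def vectors_per_batch(batch_size, clip_seconds, sample_rate=24_000, frame_rate=75):
    """Number of embedding vectors each RVQ stage sees in one batch."""
    hop = sample_rate // frame_rate  # 320 samples per vector
    frames_per_clip = int(clip_seconds * sample_rate) // hop
    return batch_size * frames_per_clip


print(vectors_per_batch(16, 1.0))  # → 1200, already above a 1024-entry codebook
```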

ah got it! so it is really only the init we have to worry about

lucidrains, Oct 22 '21 00:10

@wesbz I think the kmeans init case will be fine now with https://github.com/lucidrains/vector-quantize-pytorch/commit/baf249e914fb083cfa7a385d389bcb4f315f9ef7 but I still need to solve it for the non-kmeans init, or whenever a code expires as well (not sure what to set the cluster size to)

lucidrains, Oct 22 '21 01:10
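For the k-means case, one way to avoid the zero-start problem, sketched under the assumption that the EMA keeps a per-code count `cluster_size` and a per-code running sum `embed_sum`, with the codebook recovered as `embed_sum / cluster_size`: seed both buffers from the k-means result so they agree with the initial codebook from the first step. The buffer and function names are hypothetical, and this does not address random init or freshly expired codes.

```python
import torch

def init_ema_buffers_from_kmeans(centroids, counts):
    """Seed EMA statistics from k-means output instead of zeros.

    centroids: (K, D) k-means centroids; counts: (K,) samples per centroid.
    """
    cluster_size = counts.to(centroids.dtype).clone()
    # Running sum per code, chosen so that embed_sum / cluster_size == centroids.
    embed_sum = centroids * cluster_size.unsqueeze(-1)
    return cluster_size, embed_sum


centroids = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
counts = torch.tensor([3, 1])
cluster_size, embed_sum = init_ema_buffers_from_kmeans(centroids, counts)
recovered = embed_sum / cluster_size.unsqueeze(-1)  # a zero count would need smoothing
```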