vector-quantize-pytorch
Missing feature to reproduce SoundStream's Residual Vector Quantizer
Hi, thanks for this cool work! I couldn't help but notice that a few of the features the SoundStream article uses to improve codebook usage are missing, so this isn't yet an exact implementation of that work:
- "First, instead of using a random initialization for the codebook vectors, we run the k-means algorithm on the first training batch and use the learned centroids as initialization"
- "Second, as proposed in [34], when a codebook vector has not been assigned any input frame for several batches, we replace it with an input frame randomly sampled within the current batch." I'm currently working on an implementation of this work, I'll use and adapt your code for this purpose but thought you might want to know about it. I'll keep you posted :)
@wesbz hey Wes! thanks for raising this issue! I can finish the first bullet point tonight, and as for the second, it seems tied to the other codebook revival issue, and i can tackle that all at once next week
@lucidrains thanks for your response! I have completed the first bullet point, but I'd be more than glad to see your implementation of this!
I completed (I think) both points, you can see it here
@wesbz oh whoops! i just finished number one too https://github.com/lucidrains/vector-quantize-pytorch/commit/d28d851aa514b244120c2fd286a7bab380ccd127 do you want to review it and see if that works for your use-case?
@wesbz i'll take a look at your number 2 solution later!
@wesbz also, would you be interested in testing https://github.com/lucidrains/vector-quantize-pytorch/issues/2#issuecomment-937387129 if i were to build it?
@wesbz added my version of bullet point 2 here https://github.com/lucidrains/vector-quantize-pytorch#expiring-stale-codes
Hi, thanks for notifying me! I'll try to use your version, since it seems very well coded and it avoids useless code duplication in my implementation of SoundStream :joy: I'll let you know if I ever encounter something weird :wink: Thanks again
Also! Regarding your implementation of factorized codes and $l_2$-normalized codes: I don't think they're used in SoundStream, so I won't have the occasion to test them :sweat_smile:
Oh, and by the way, something from the SoundStream article that's still missing (though I don't need it, so I didn't mention it in my first comment) is the bitrate scalability of the RVQ via quantizer dropout. I can do a pull request for that later ;)
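For reference, by quantizer dropout I mean roughly this: during training, only the first `n_q` residual quantizers are used, with `n_q` sampled uniformly at random, so one model covers several bitrates at inference. A hedged sketch, reusing the `EMAQuantizer` from above (`rvq_forward` and its signature are mine):

```python
import random
import torch

def rvq_forward(x, quantizers, training=True):
    # residual VQ with quantizer dropout: train with a random prefix of the
    # quantizer stack so the same model can run at multiple bitrates
    n_q = random.randint(1, len(quantizers)) if training else len(quantizers)
    residual, quantized_out = x, torch.zeros_like(x)
    for quantizer in quantizers[:n_q]:
        quantized, _ = quantizer(residual)   # each stage quantizes the residual
        residual = residual - quantized
        quantized_out = quantized_out + quantized
    return quantized_out
```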
@wesbz ok sounds good! thank you for the PR! i realized there's an issue with the cluster size; we should probably be keeping track of the init state across all codes, and on the first EMA update, just set the cluster size to the batch's value instead of assuming 0 for the moving average
I'll check out the dropout later this evening as well!
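If I'm reading the cluster size issue right, the fix would look something like this, reusing `counts` and `decay` from the sketch above (`stats_initted` is a hypothetical extra buffer, not code from the repo):

```python
# hypothetical: registered in __init__ alongside the other buffers
# self.register_buffer('stats_initted', torch.tensor(False))

if not self.stats_initted:
    # first update: copy the batch counts directly instead of blending
    # them into the zero-initialized moving average
    self.cluster_size.copy_(counts)
    self.stats_initted.fill_(True)
else:
    self.cluster_size.mul_(self.decay).add_(counts, alpha=1 - self.decay)
```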
Shouldn't the threshold also be based on the cluster size normalized? Wouldn't one batch that is smaller than usual throw things off?
Yes, I was also wondering about the first use of the EMA. I don't know what is commonly done when working with these statistics. Regarding the threshold, what do you mean by "cluster size normalized"? In SoundStream they use a decay of 0.99, so there aren't many scenarios where the EMA would fall below 2, except if the code has not been (or has barely been) used for quite a while.
If there are few vectors in a batch and the codebook size is large, then many codes may not be chosen. Say, 256 (vectors) vs 1024 (codebook size). Maybe setting the threshold to (number of codes in a batch) / (batch size), or something like that, is better? (I haven't experimented with it yet.)
by the way, nice work ! 👍
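One way to picture the normalization being suggested (names and numbers are purely illustrative, not from the repo):

```python
# with 256 vectors in the batch and a codebook of 1024 codes, even perfectly
# uniform usage gives each code only 256 / 1024 = 0.25 assignments, so a
# fixed threshold could expire codes just because the batch was small
expected_per_code = num_batch_vectors / codebook_size  # e.g. 0.25
threshold = base_threshold * expected_per_code          # scale the fixed threshold
expired = cluster_size < threshold
```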
the easy way out would be to have an extra hyperparameter so that codes only expire when there are sufficient vectors in the batch, maybe min_batch_vectors_check_expiry or something (i'm really bad at naming)
In SoundStream, the encoder is built in such a way that a 24kHz waveform is transformed into 75Hz embeddings. So with a batch of 16×1-second audio samples, your RVQ already has to process 16×75 = 1200 vectors, more than the 1024 codes. And again, as long as you have a decay factor close to 1, you shouldn't encounter instability in the codebooks. I'll see what it looks like once I get it working correctly :)
ah got it! so it is really only the init we have to worry about
@wesbz I think the kmeans init case will be fine now with https://github.com/lucidrains/vector-quantize-pytorch/commit/baf249e914fb083cfa7a385d389bcb4f315f9ef7 but I still need to solve it for the non-kmeans init, as well as for whenever a code expires (not sure what to set the cluster size to)
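For the expiry case, one option (purely illustrative; the right reset value is exactly the open question here) is to give replaced codes a non-zero EMA count so they aren't immediately expired again. Building on the earlier sketch's `expired` and `replacements`:

```python
# when replacing expired codes with batch samples, also reset their EMA
# counts, e.g. to the mean usage of the surviving codes
# (assumes not every code expired at once; the "correct" value is open)
self.codebook[expired] = replacements.detach()
self.cluster_size[expired] = self.cluster_size[~expired].mean()
```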