vector-quantize-pytorch
Building intuition about the latent quantizer
Hi!
I'd like to build a bit more intuition about the latent quantizer in order to do a hyperparameter sweep.
I'm trying to develop a codebook for a 1D signal, and I've put a simple encoder and decoder (a small ResNet with a few layers) on either side of the latent quantizer to expand the signal from 1 to 32 channels without changing the sample rate, i.e. no stride in the CNN.
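Schematically the model looks something like this (a simplified sketch of my setup, with plain Conv1d stacks standing in for the actual ResNet blocks; not code from the library itself):

```python
import torch
from torch import nn
from vector_quantize_pytorch import LatentQuantize


class Autoencoder1D(nn.Module):
    # Simplified stand-in for my model: stride-1 convs only, so the time
    # resolution never changes -- only the channel count (1 -> 32 -> 1).
    def __init__(self, dim = 32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, dim, kernel_size = 7, padding = 3),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel_size = 7, padding = 3),
        )
        self.quantizer = LatentQuantize(
            levels = [5, 5, 8],
            dim = dim,
            commitment_loss_weight = 0.1,
            quantization_loss_weight = 0.1,
        )
        self.decoder = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size = 7, padding = 3),
            nn.GELU(),
            nn.Conv1d(dim, 1, kernel_size = 7, padding = 3),
        )

    def forward(self, x):              # x: (batch, 1, time)
        z = self.encoder(x)            # (batch, 32, time)
        z = z.transpose(1, 2)          # features last before quantizing (my assumption about the expected layout for 1D input)
        z_q, indices, quant_loss = self.quantizer(z)
        z_q = z_q.transpose(1, 2)      # back to (batch, 32, time) for the conv decoder
        recon = self.decoder(z_q)
        return recon, indices, quant_loss
```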
I'm using:
```python
from torch.optim import Adam
from vector_quantize_pytorch import LatentQuantize

quantizer = LatentQuantize(
    levels = [5, 5, 8],
    dim = 32,
    commitment_loss_weight = 0.1,
    quantization_loss_weight = 0.1,
)

optimizer = Adam(
    self.model.parameters(), lr = 1e-5, weight_decay = 1.0,  # high weight decay as per the suggestion in the paper
)
```
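The training step just adds the quantizer's auxiliary loss to the reconstruction loss. A rough sketch, assuming `LatentQuantize` returns `(quantized, indices, loss)` as in the README example and that the returned loss already includes the commitment/quantization weights above:

```python
import torch.nn.functional as F

def training_step(model, optimizer, x):
    # model is the encoder -> LatentQuantize -> decoder stack sketched above
    recon, _, quant_loss = model(x)
    loss = F.mse_loss(recon, x) + quant_loss   # reconstruction + (already-weighted) quantizer loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```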
After 100 epochs, it does a decent job reconstructing the signal. Red is the reconstruction, blue is the original.
Given that there's no time decimation and there's a channel expansion from 1 to 32, I would expect it to reconstruct the signal better than this. Other quantizers, like plain VectorQuantize, do a near-perfect job basically right away, although their latent spaces aren't super useful in a diffusion model, which is why I'm experimenting with new quantizers.
I'm wondering what "knobs" can be tuned on the latent quantizer to get a better result. I still have plenty of headroom on my A100, so if more levels or a higher dimension would help, I can try that, but I didn't want to start tweaking hyperparameters until I had a better sense of their tradeoffs. For concreteness, the kind of sweep I had in mind is sketched below.
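The values here are placeholders of mine, not recommendations from the paper or the README; the only parameter names I'm relying on are the `LatentQuantize` arguments already shown above.

```python
from itertools import product
from vector_quantize_pytorch import LatentQuantize

# Placeholder sweep grid -- more / larger levels means a bigger effective codebook,
# dim is the latent width the encoder has to produce.
sweep = {
    "levels": [[5, 5, 8], [8, 8, 8], [8, 6, 5, 5]],
    "dim": [32, 64],
    "commitment_loss_weight": [0.05, 0.1, 0.25],
    "quantization_loss_weight": [0.05, 0.1, 0.25],
}

for levels, dim, commit_w, quant_w in product(*sweep.values()):
    quantizer = LatentQuantize(
        levels = levels,
        dim = dim,
        commitment_loss_weight = commit_w,
        quantization_loss_weight = quant_w,
    )
    # ...rebuild the encoder/decoder around this quantizer, train, and log reconstruction error
```

Thanks for any tips!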