vector-quantize-pytorch
Seeking clarifications regarding learnable codebook
Hi,
I am interested in learning codewords via gradient descent (not EMA) that are L2-normalized and mutually orthonormal. To do so, I created the vector quantizer with the following configuration:
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True,                   # L2-normalize inputs and codewords
    orthogonal_reg_weight = 10,
    orthogonal_reg_max_codes = 128,
    orthogonal_reg_active_codes_only = False,
    learnable_codebook = True,               # codebook is updated by gradients
    ema_update = False
)
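For context, I use the quantizer in the standard way from the README (the shapes below are just an example):

import torch

x = torch.randn(1, 1024, 256)       # (batch, seq_len, dim)
quantized, indices, loss = vq(x)    # returned loss: commitment plus auxiliary terms, as I understand it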
However, I noticed in the implementation at line 1071 that the loss contains only a single term, which pushes the input embeddings towards their corresponding quantized (codeword) embeddings. It does not include a second term enforcing the reverse direction, i.e. pushing the codewords towards the (detached) encoder outputs. Am I missing something here?
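For reference, the two-term objective I had in mind is the standard VQ-VAE loss; here is a minimal sketch (not this library's code; beta is the usual commitment weight):

import torch.nn.functional as F

def vq_loss(z_e, z_q, beta = 0.25):
    # codebook term: pushes codewords towards the (detached) encoder outputs
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # commitment term: pushes encoder outputs towards the (detached) codewords
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    return codebook_loss + beta * commitment_loss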
Also, suppose I instead create a vector quantizer that learns the codebook via EMA, with the following configuration:
vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True,
    orthogonal_reg_weight = 10,
    orthogonal_reg_max_codes = 128,
    orthogonal_reg_active_codes_only = False,
    learnable_codebook = False,              # codebook receives no gradients
    ema_update = True,                       # codewords updated via exponential moving average
    decay = 0.8
)
Will the codewords still be learned to be orthonormal in this configuration? Since the EMA update bypasses gradient descent, I am not sure whether the gradient-based orthogonal regularization loss can still shape the codebook.
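My understanding of the standard EMA codebook update is sketched below (names and shapes are illustrative, not this library's internals); since the codewords are overwritten from assignment statistics rather than gradients, I don't see how the orthogonal regularization loss would reach them:

import torch

def ema_codebook_update(codebook, cluster_size, embed_avg, z_e, assign_onehot,
                        decay = 0.8, eps = 1e-5):
    # running averages of per-code assignment counts and assigned-vector sums
    cluster_size.mul_(decay).add_(assign_onehot.sum(dim = 0), alpha = 1 - decay)
    embed_avg.mul_(decay).add_(assign_onehot.t() @ z_e, alpha = 1 - decay)
    # Laplace smoothing avoids division by zero for unused codes
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + codebook.shape[0] * eps) * n
    # codewords are overwritten from statistics -- no gradient step involved
    codebook.copy_(embed_avg / smoothed.unsqueeze(1))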