vector-quantize-pytorch
No way of training the codebook
Hi! Could you please explain how the codebook vectors are updated if the codebook vectors are not required to be orthogonal?
The embed tensors in both the Euclidean and CosineSim codebooks are registered as buffers, so they can't be updated by the optimizer at all, and there is no loss on the codebook vectors that moves them closer to the inputs.
Am I missing something? It seems that right now there is no way of updating the codebook vectors without the orthogonal loss.
@RafailFridman they are updated through an exponential moving average of the cluster statistics during training, as shown here
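For anyone wondering what that means in practice, below is a minimal, self-contained sketch of an EMA codebook update in plain PyTorch. The function name, shapes, and smoothing constant are illustrative assumptions rather than the library's exact implementation; the point is that each code is pulled toward the running average of the inputs assigned to it, with no gradient flowing to the codebook.

import torch
import torch.nn.functional as F

def ema_codebook_update(x, embed, embed_avg, cluster_size, decay=0.1, eps=1e-5):
    # illustrative sketch, not the library's source
    # x: (num_vectors, dim) flattened inputs, embed: (codebook_size, dim)
    dist = torch.cdist(x, embed)                        # distance of every input to every code
    onehot = F.one_hot(dist.argmin(dim=-1), embed.shape[0]).type(x.dtype)

    # per-batch cluster statistics: counts and sums of the assigned inputs
    batch_cluster_size = onehot.sum(dim=0)              # (codebook_size,)
    batch_embed_sum = onehot.t() @ x                    # (codebook_size, dim)

    # exponential moving averages of those statistics (decay = fraction of the old value kept)
    cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
    embed_avg.mul_(decay).add_(batch_embed_sum, alpha=1 - decay)

    # Laplace smoothing so codes with no assignments do not divide by zero
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + embed.shape[0] * eps) * n
    embed.copy_(embed_avg / smoothed.unsqueeze(-1))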
It seems totally broken to me. Please see the following code and output:
import torch
from vector_quantize_pytorch import VectorQuantize

torch.manual_seed(3026076450)
x = torch.randn((1, 4, 2))

vq = VectorQuantize(
    dim = 2,
    codebook_size = 4,        # codebook size
    decay = 0.1,              # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 0.    # the weight on the commitment loss
)

print("codebook")
print(vq.codebook)
print(vq._codebook.embed_avg)

for _ in range(1000):
    y, *_ = vq(x)

print("codebook")
print(vq.codebook)
print(vq._codebook.embed_avg)

print("x, quantized x")
print(x)
print(y)
codebook
tensor([[ 0.8001, -0.2969],
[-0.6842, -0.4783],
[-0.6681, 0.4355],
[ 0.1892, -0.0102]])
tensor([[[ 0.8001, -0.2969],
[-0.6842, -0.4783],
[-0.6681, 0.4355],
[ 0.1892, -0.0102]]])
codebook
tensor([[ 8.0007e+04, -2.9693e+04],
[-6.8416e+04, -4.7833e+04],
[-6.6815e+04, 4.3555e+04],
[ 4.7304e-02, -2.5412e-03]])
tensor([[[ 0.8001, -0.2969],
[-0.6842, -0.4783],
[-0.6681, 0.4355],
[ 0.1892, -0.0102]]])
x, quantized x
tensor([[[ 0.8369, 0.9166],
[ 0.1462, 0.1903],
[ 0.3105, -0.5587],
[ 0.1310, 0.1395]]])
tensor([[[ 0.0473, -0.0025],
[ 0.0473, -0.0025],
[ 0.0473, -0.0025],
[ 0.0473, -0.0025]]])
Looks like this commit removed the update of self.embed_avg:
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
https://github.com/lucidrains/vector-quantize-pytorch/commit/8716f68d5549f5252d81e45651271f466d639356#diff-8bcd9c958b614ce130dc4091c094d4cfcc2023716823c181e3a9edaddbcd433dL250
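For reference, ema_inplace is just an in-place exponential moving average; a minimal sketch of such a helper (not necessarily the library's exact source) is:

import torch

def ema_inplace(moving_avg, new, decay):
    # moving_avg <- decay * moving_avg + (1 - decay) * new, updated in place
    moving_avg.mul_(decay).add_(new, alpha=1 - decay)

If that call is skipped for embed_avg while cluster_size keeps being decayed toward the batch counts, then for codes that receive no assignments embed_avg stays at its initial value while the smoothed cluster_size shrinks toward eps. Assuming the codebook is recovered as embed_avg divided by the smoothed cluster_size, that would explain why an initial entry like 0.8001 blows up to roughly 8.0e+04 in the output above.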
@tasptz I found the same issue: embed does not seem to be updated during training.
@lucidrains Could you take a look? Why was the update of embed_avg removed when the multihead VQ feature was introduced?
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)