vector-quantize-pytorch

No way of training the codebook

RafailFridman opened this issue on Jul 26 '22 · 5 comments

Hi! Could you please explain how the codebook vectors are updated when they are not required to be orthogonal?

  1. The embed tensors in both the Euclidean and CosineSim codebooks are registered as buffers, so they cannot receive gradient updates (see the short illustration below).
  2. There is no loss term on the codebook vectors that moves them closer to the inputs.

Am I missing something? It seems that right now there is no way of updating the codebook vectors without the orthogonal loss.
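
For context on point 1: a tensor registered with register_buffer is saved in the state dict but excluded from parameters(), so the optimizer never touches it; it can only change through an explicit in-place update such as the EMA step mentioned below. A minimal illustration (not code from the library):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('embed', torch.randn(4, 2))  # buffer: no gradient, invisible to the optimizer
        self.weight = nn.Parameter(torch.randn(4, 2))     # parameter: updated by the optimizer

toy = Toy()
print([name for name, _ in toy.named_parameters()])  # ['weight'] only; 'embed' is not a parameter
print(toy.embed.requires_grad)                        # False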

RafailFridman · Jul 26 '22

@RafailFridman they are updated through an exponential moving average of the cluster statistics during training, as shown here
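
For reference, the EMA scheme (as described in the VQ-VAE paper) keeps, per codebook entry, a running count of the vectors assigned to it and a running sum of those vectors, and sets the entry to sum / count. A rough sketch, not the library's exact code (shapes and helper names here are illustrative):

import torch
import torch.nn.functional as F

def ema_inplace(moving_avg, new, decay):
    # moving_avg <- decay * moving_avg + (1 - decay) * new
    moving_avg.data.mul_(decay).add_(new, alpha = 1 - decay)

def ema_codebook_update(x, embed, cluster_size, embed_avg, decay = 0.8, eps = 1e-5):
    # x: (num_vectors, dim), embed / embed_avg: (codebook_size, dim), cluster_size: (codebook_size,)
    # 1. hard-assign every input vector to its nearest codebook entry
    assign = torch.cdist(x, embed).argmin(dim = -1)
    onehot = F.one_hot(assign, embed.shape[0]).type(x.dtype)

    # 2. EMA of how many vectors landed on each entry, and of the sum of those vectors
    ema_inplace(cluster_size, onehot.sum(dim = 0), decay)
    ema_inplace(embed_avg, onehot.t() @ x, decay)

    # 3. new codebook = running sum / (Laplace-smoothed) running count
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + embed.shape[0] * eps) * n
    embed.data.copy_(embed_avg / smoothed.unsqueeze(-1))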

lucidrains · Jul 26 '22

It seems totally broken to me. Please see the following code and output

import torch
torch.manual_seed(3026076450)
x = torch.randn((1, 4, 2))

from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
    dim = 2,
    codebook_size = 4,     # codebook size
    decay = 0.1,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 0.   # the weight on the commitment loss
)

print("codebook")
print(vq.codebook)
print(vq._codebook.embed_avg)

for _ in range(1000):
    y, *_ = vq(x)

print("codebook")
print(vq.codebook)
print(vq._codebook.embed_avg)

print("x, quantized x")
print(x)
print(y)

Output:

codebook
tensor([[ 0.8001, -0.2969],
        [-0.6842, -0.4783],
        [-0.6681,  0.4355],
        [ 0.1892, -0.0102]])
tensor([[[ 0.8001, -0.2969],
         [-0.6842, -0.4783],
         [-0.6681,  0.4355],
         [ 0.1892, -0.0102]]])
codebook
tensor([[ 8.0007e+04, -2.9693e+04],
        [-6.8416e+04, -4.7833e+04],
        [-6.6815e+04,  4.3555e+04],
        [ 4.7304e-02, -2.5412e-03]])
tensor([[[ 0.8001, -0.2969],
         [-0.6842, -0.4783],
         [-0.6681,  0.4355],
         [ 0.1892, -0.0102]]])
x, quantized x
tensor([[[ 0.8369,  0.9166],
         [ 0.1462,  0.1903],
         [ 0.3105, -0.5587],
         [ 0.1310,  0.1395]]])
tensor([[[ 0.0473, -0.0025],
         [ 0.0473, -0.0025],
         [ 0.0473, -0.0025],
         [ 0.0473, -0.0025]]])

tasptz · Sep 30 '22

Looks like this commit removed the update of self.embed_avg:

ema_inplace(self.embed_avg, embed_sum.t(), self.decay)

https://github.com/lucidrains/vector-quantize-pytorch/commit/8716f68d5549f5252d81e45651271f466d639356#diff-8bcd9c958b614ce130dc4091c094d4cfcc2023716823c181e3a9edaddbcd433dL250
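
If that is indeed the cause, it would also explain the output above: embed_avg stays frozen at its initial values while the codebook itself still drifts through the normalization step. A rough sketch of the restored update path (only the quoted ema_inplace call is from the repository; the surrounding names are assumptions):

# sketch of the codebook update; variable names other than the quoted call are assumed
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)   # EMA of per-code assignment counts
embed_sum = flatten.t() @ embed_onehot                            # per-code sum of assigned vectors, (dim, codebook_size)
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)            # the update that the commit dropped
self.embed.data.copy_(self.embed_avg / self.cluster_size.clamp(min = 1e-5).unsqueeze(-1))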

tasptz · Sep 30 '22

@tasptz I'm seeing the same issue: embed does not seem to be updated during training.

npuichigo · Nov 24 '22

@lucidrains Could you take a look? Why was the update of embed_avg removed when the multi-headed VQ feature was introduced?

ema_inplace(self.embed_avg, embed_sum.t(), self.decay)

npuichigo · Nov 24 '22