vector-quantize-pytorch
Potential bug with lower codebook dimension / get_output_from_indices (improved VQGAN)
Dear @lucidrains,
while experimenting with this wonderful library, I found some odd behaviour when using lower codebook dimensions:
```python
# everything okay
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    codebook_dim = 16   # the paper proposes setting this to 32 or as low as 8 to increase codebook usage
)
vq.eval()

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
When I try to recover the output from the indices, the script crashes:

```python
output = vq.get_output_from_indices(indices)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16384 and 16x256)
print(output.shape)
```
I tracked this down to `get_codes_from_indices`:

```python
if not is_multiheaded:
    codes = codebook[indices]
    return rearrange(codes, '... h d -> ... (h d)') # why is this line required?
```
The reason is that `self.project_out` has to be applied whenever `dim != codebook_dim`, and it expects inputs with a trailing `codebook_dim` axis, which the rearrange collapses away.
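To illustrate the mismatch, here is a minimal sketch using a plain `nn.Linear` in place of `vq.project_out` (an assumption about the library's internals, inferred from the error message; the `reshape` mimics the `rearrange` above):

```python
import torch
import torch.nn as nn

# hypothetical stand-in for project_out: maps codebook_dim=16 back to dim=256
project_out = nn.Linear(16, 256)

codes = torch.randn(1, 1024, 16)      # (batch, seq, codebook_dim)
out = project_out(codes)              # fine: nn.Linear acts on the last axis
print(out.shape)                      # torch.Size([1, 1024, 256])

# the rearrange '... h d -> ... (h d)' collapses the last two axes
flat = codes.reshape(1, 1024 * 16)    # (1, 16384)

# after which the projection fails with the error from above
try:
    project_out(flat)
except RuntimeError as e:
    print(e)                          # mat1 and mat2 shapes cannot be multiplied
```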
If I remove `rearrange(codes, '... h d -> ... (h d)')`, everything works as expected:

```python
codes = codebook[indices]           # shape (1, 1024, 16)
proj_out = vq.project_out(codes)

torch.all(quantized == proj_out)    # returns True
```
Please find the full example on Colab. In case I made a mistake, I apologize; I am still new to PyTorch and this library...
Thanks, Nikolai