vector-quantize-pytorch

[BUG] Residual VQ - self.embed.data[ind][mask] = sampled - RuntimeError: shape mismatch:

Open dwromero opened this issue 1 year ago • 10 comments

Hi all,

I noticed that using ResidualVQ as:

ResidualVQ(
    dim=Z_CHANNELS,  # 512
    num_quantizers=NUM_QUANTIZERS,  # 2
    codebook_size=CODEBOOK_SIZE,  # 16 * 1024
    stochastic_sample_codes=True,
    shared_codebook=True,
    commitment_weight=1.0,
    kmeans_init=True,
    threshold_ema_dead_code=2,
    quantize_dropout=True,
    quantize_dropout_cutoff_index=1,
    quantize_dropout_multiple_of=1,
)

Leads to the following error:

  File "/mnt/workspace/Projects/autoencoder.py", line 261, in forward
    z_tilde, _, commit_loss = self.vq(z)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_vq.py", line 183, in forward
    quantized, *rest = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 919, in forward
    quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 514, in forward
    self.expire_codes_(x)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 443, in expire_codes_
    self.replace(batch_samples, batch_mask = expired_codes)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 428, in replace
    self.embed.data[ind][mask] = sampled
RuntimeError: shape mismatch: value tensor of shape [9330, 512] cannot be broadcast to indexing result of shape [9331, 512]

This happens intermittently during training in a multi-node setting; I have not found a deterministic way to trigger it. Any idea what the cause could be?
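For reference, the same class of error can be reproduced in isolation with plain PyTorch. This is only a minimal sketch of the failing operation, not the library's code: `masked_replace` is a hypothetical helper, and the sizes are made up. It shows that a boolean-mask assignment fails whenever the number of replacement rows differs from the number of `True` entries in the mask, which matches the 9330 vs. 9331 mismatch above. It does not explain why the two counts diverge in the multi-node run.

```python
import torch


def masked_replace(embed: torch.Tensor, mask: torch.Tensor, sampled: torch.Tensor) -> None:
    """Hypothetical helper: overwrite the rows of `embed` selected by `mask`.

    Checks the row count up front so a mismatch is reported with both numbers.
    """
    n_expired = int(mask.sum())
    if sampled.shape[0] != n_expired:
        raise ValueError(
            f"mask selects {n_expired} rows but {sampled.shape[0]} replacement rows were given"
        )
    embed[mask] = sampled


# Toy sizes, chosen only for illustration.
codebook_size, dim = 8, 4
embed = torch.zeros(codebook_size, dim)
mask = torch.zeros(codebook_size, dtype=torch.bool)
mask[[0, 2, 5]] = True  # three "expired" codes

# Matching counts: three replacement rows for three masked rows.
masked_replace(embed, mask, torch.ones(3, dim))

# Mismatched counts: two replacement rows for three masked rows.
# The raw assignment raises a RuntimeError, as in the traceback above.
try:
    embed[mask] = torch.ones(2, dim)
    raw_error = None
except RuntimeError as err:
    raw_error = str(err)

print(raw_error)
```

Logging `mask.sum()` and `sampled.shape[0]` on each rank just before the assignment should at least show which of the two counts is off by one when the failure occurs.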

dwromero, Jun 25 '24 11:06