vector-quantize-pytorch
[BUG] ResidualVQ: RuntimeError (shape mismatch) at self.embed.data[ind][mask] = sampled
Hi all,
I noticed that constructing ResidualVQ as follows:

```python
from vector_quantize_pytorch import ResidualVQ

# Constants as used in my training script
Z_CHANNELS = 512
NUM_QUANTIZERS = 2
CODEBOOK_SIZE = 16 * 1024

vq = ResidualVQ(
    dim=Z_CHANNELS,
    num_quantizers=NUM_QUANTIZERS,
    codebook_size=CODEBOOK_SIZE,
    stochastic_sample_codes=True,
    shared_codebook=True,
    commitment_weight=1.0,
    kmeans_init=True,
    threshold_ema_dead_code=2,
    quantize_dropout=True,
    quantize_dropout_cutoff_index=1,
    quantize_dropout_multiple_of=1,
)
```

leads to the following error:
```
  File "/mnt/workspace/Projects/autoencoder.py", line 261, in forward
    z_tilde, _, commit_loss = self.vq(z)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/residual_vq.py", line 183, in forward
    quantized, *rest = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 919, in forward
    quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 514, in forward
    self.expire_codes_(x)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 443, in expire_codes_
    self.replace(batch_samples, batch_mask = expired_codes)
  File "/usr/local/lib/python3.10/dist-packages/vector_quantize_pytorch/vector_quantize_pytorch.py", line 428, in replace
    self.embed.data[ind][mask] = sampled
RuntimeError: shape mismatch: value tensor of shape [9330, 512] cannot be broadcast to indexing result of shape [9331, 512]
```
The error occurs intermittently during training in a multi-node setting. Any idea what the cause could be?
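For context, the failing line in the traceback is a boolean-mask assignment: `target[mask] = values` only works when `values` has as many rows as `mask` has `True` entries (or a single row that PyTorch can broadcast). The sketch below reproduces that failure mode on toy tensors; it does not reproduce the library bug itself. `masked_row_assign` is a hypothetical helper, not part of vector-quantize-pytorch, that fails early with both counts, which might help log how far apart the expired-code mask and the sampled vectors are when the error occurs.

```python
import torch


def masked_row_assign(target: torch.Tensor, mask: torch.Tensor, values: torch.Tensor) -> None:
    """Hypothetical helper (not a library API): write `values` into the rows of
    `target` selected by the boolean `mask`, failing early with both row counts.
    This is stricter than plain indexing, which would also broadcast a single row."""
    num_selected = int(mask.sum().item())
    if values.shape[0] != num_selected:
        raise ValueError(
            f"mask selects {num_selected} rows but {values.shape[0]} replacement rows were provided"
        )
    target[mask] = values  # in-place write into the selected rows


# Toy stand-in for one codebook: 5 codes of dimension 4 (sizes are illustrative only).
embed = torch.zeros(5, 4)
expired = torch.tensor([True, False, True, True, False])  # 3 "expired" codes

# Matching counts: 3 selected rows and 3 replacement rows, so the write succeeds.
masked_row_assign(embed, expired, torch.ones(3, 4))

# One row short, analogous to 9330 vs 9331 in the traceback: plain indexing raises RuntimeError.
try:
    embed[expired] = torch.ones(2, 4)
except RuntimeError as err:
    print(err)
```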