vector-quantize-pytorch
kmeans and ddp hangs
Kmeans initialization under DDP hangs for me. DDP is initialized by PyTorch Lightning in my case. I have several questions:
In https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L98,
all_num_samples = all_gather_sizes(local_samples, dim = 0) is called. Should it be dim = 1, since dim 0 is the codebook dimension and dim 1 holds the per-rank samples?
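For illustration, here is a minimal sketch of what I mean (this is my own code, assuming local_samples is laid out as (num_codebooks, num_samples_per_rank, dim), so the size that actually varies across ranks is on dim 1):

```python
import torch
import torch.distributed as distributed

def gather_sample_counts(local_samples: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # each rank reports how many samples it holds along `dim`
    local_count = torch.tensor([local_samples.shape[dim]], device=local_samples.device)
    counts = [torch.empty_like(local_count) for _ in range(distributed.get_world_size())]
    distributed.all_gather(counts, local_count)
    return torch.cat(counts)  # one count per rank, shape (world_size,)
```

With dim = 0 every rank would just report the (identical) number of codebooks, not how many samples it contributes.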
Then at https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L93 it just hangs for me. I am not totally sure, but I believe distributed.broadcast at
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L90
is called with incompatible shapes across ranks. Per https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast,
the tensor must have the same number of elements in all processes participating in the collective.
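One way to satisfy that requirement would be to broadcast the varying size first and let non-source ranks allocate a matching buffer before the actual broadcast. A minimal sketch of that idea (my own hypothetical helper, not the library's code; it assumes the size that differs across ranks is along dim 0):

```python
import torch
import torch.distributed as distributed

def broadcast_with_size(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    rank = distributed.get_rank()
    # step 1: agree on the varying size (assumed to be along dim 0)
    size = torch.tensor([tensor.shape[0] if rank == src else 0], device=tensor.device)
    distributed.broadcast(size, src=src)
    # step 2: non-source ranks allocate a buffer with the agreed number of elements,
    # so every rank passes a tensor of the same size to the collective
    if rank != src:
        tensor = torch.empty(int(size.item()), *tensor.shape[1:],
                             device=tensor.device, dtype=tensor.dtype)
    distributed.broadcast(tensor, src=src)
    return tensor
```

If the current code broadcasts tensors whose per-rank shapes already differ, I would expect exactly the kind of hang I am seeing.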