vector-quantize-pytorch
How about ChamferDistance instead of cdist to calculate KNN?
Hi, I propose using a KNN function from the point cloud field instead of torch.cdist.
VectorQuantization finds the nearest vector in the codebook (B, M, D) for a given input (B, N, D).
This repository uses torch.cdist to find the nearest codebook entries (e.g. (B, N, D), (B, M, D) -> indices of shape (B, N)) here, which computes a similarity matrix (B, N, M).
This method requires keeping the similarity matrix (B, N, M) in memory, which becomes inefficient as N or M grows. In contrast, ChamferDistance does not materialize the full similarity matrix (B, N, M); it uses a CUDA reduction instead. For example, given two input sequences of vectors (B, N, D) and (B, M, D), it can directly return (B, N, D) and (B, N; int64), where the first output (B, N, D) has the same shape as the input but holds the corresponding codebook values, and the second (B, N; int64) holds the indices of the nearest codebook entries. This is similar to how memory-efficient attention reduces the memory required for Transformer computation.
This KNN implementation is available in off-the-shelf libraries such as pytorch3d.ops.knn(); it is differentiable, DDP-safe, and works on both CPU and GPU.
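The memory-reduction idea above can be sketched in plain PyTorch by chunking over the N dimension, so only a (B, n, M) slice of the similarity matrix exists at any time rather than the full (B, N, M). This is a simplified stand-in for the fused CUDA reduction described above, not the actual ChamferDistance or pytorch3d kernel; the function name and chunk size are illustrative.

```python
import torch

def chunked_nearest_codebook(x, codebook, chunk_size=1024):
    """Nearest-codebook lookup without materializing the full (B, N, M)
    similarity matrix at once.

    x:        (B, N, D) input vectors
    codebook: (B, M, D) codebook vectors
    Returns (B, N, D) quantized vectors and (B, N) int64 indices,
    matching the output pair described in the proposal.
    """
    quantized, indices = [], []
    for chunk in x.split(chunk_size, dim=1):   # (B, n, D), n <= chunk_size
        dist = torch.cdist(chunk, codebook)    # only (B, n, M) is alive here
        idx = dist.argmin(dim=-1)              # (B, n) nearest-entry indices
        # Gather the selected codebook vectors: (B, n, D)
        q = torch.gather(
            codebook, 1,
            idx.unsqueeze(-1).expand(-1, -1, codebook.shape[-1]))
        quantized.append(q)
        indices.append(idx)
    return torch.cat(quantized, dim=1), torch.cat(indices, dim=1)
```

Peak memory for the distance buffer drops from O(B·N·M) to O(B·chunk_size·M), while the result is identical to the full cdist + argmin; a fused kernel like the one in pytorch3d goes further by never writing the distance slice to global memory at all.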
@Kitsunetic hey! thanks for proposing this
i was actually looking into kmeans++
so basically this is identical to kmeans except more memory efficient?
@Kitsunetic since kmeans is only calculated at the start, i don't think it really affects training all that much?