
Gumbel max trick does not seem to make sense in here

Open ivallesp opened this issue 1 year ago • 0 comments

Hi all,

I want to ask a question regarding some concerns I have about the usage of the gumbel_sample method when reinmax=False.

https://github.com/lucidrains/vector-quantize-pytorch/blob/6102e37efefefb673ebc8bec3abb02d5030dd933/vector_quantize_pytorch/vector_quantize_pytorch.py#L472-L472

First, this sampling technique is mathematically equivalent to sampling from the categorical distribution directly; the Gumbel noise is doing nothing here beyond plain sampling, and the argmax makes the operation non-differentiable (I know we apply the STE later).
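To illustrate the equivalence claim above, here is a minimal sketch (not code from the repo; the logits are made up): adding Gumbel(0, 1) noise to logits and taking the argmax produces the same categorical distribution as sampling from the softmax of those logits directly.

```python
import torch

torch.manual_seed(0)

logits = torch.tensor([0.5, 1.5, 0.2, 2.0])
probs = torch.softmax(logits, dim=-1)
n = 200_000

# Gumbel-max trick: perturb the logits with Gumbel(0, 1) noise, take argmax.
gumbel = -torch.log(-torch.log(torch.rand(n, 4)))
gumbel_samples = (logits + gumbel).argmax(dim=-1)

# Plain categorical sampling from the same distribution.
cat_samples = torch.multinomial(probs, n, replacement=True)

# Empirical frequencies of both samplers match the softmax probabilities.
gumbel_freq = torch.bincount(gumbel_samples, minlength=4) / n
cat_freq = torch.bincount(cat_samples, minlength=4) / n
print(gumbel_freq)
print(cat_freq)
```

Both frequency vectors converge to softmax(logits) as n grows, which is why the Gumbel noise by itself buys nothing over torch.multinomial here.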

https://github.com/lucidrains/vector-quantize-pytorch/blob/6102e37efefefb673ebc8bec3abb02d5030dd933/vector_quantize_pytorch/vector_quantize_pytorch.py#L72-L77

Additionally, the logits are the codebook distances (dist in the first snippet above). This variable is always positive, which means the sampling is biased because the logits are bounded at zero. No gradients flow backwards from the sampling operation (because it is a Gumbel-max, not a Gumbel-softmax), so the magnitude of the logits is never adjusted to improve the sampling.
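A small sketch of the no-gradient point (again not code from the repo; the logits and values are made up): the hard one-hot produced by argmax is cut off from the graph, so any gradient must come from the soft softmax term that the straight-through estimator bolts on.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.5, 1.5, 0.2], requires_grad=True)
values = torch.tensor([1.0, 2.0, 3.0])  # stand-in for downstream computation

# Gumbel-max: argmax is piecewise constant, so the hard one-hot
# carries no gradient back to the logits.
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
index = (logits + gumbel).argmax()
one_hot = F.one_hot(index, num_classes=3).float()

hard_loss = (one_hot * values).sum()
print(hard_loss.requires_grad)  # False: argmax severed the graph

# Straight-through estimator: hard one-hot in the forward pass,
# soft softmax gradient in the backward pass.
soft = torch.softmax(logits, dim=-1)
ste = one_hot + soft - soft.detach()

loss = (ste * values).sum()
loss.backward()
print(logits.grad)  # nonzero, but it is the gradient of the soft path only
```

All the gradient signal reaching the logits comes from the soft relaxation; the sampled one-hot itself contributes nothing, which is the bias concern raised above.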

It seems to me that this just takes a hidden variable (the distance matrix), normalizes it by an arbitrary temperature parameter, and samples from it, adding biased noise to the straight-through relaxation... What am I missing?

ivallesp avatar Feb 23 '24 19:02 ivallesp