Zero grad in residual vq
🐛 Bug Report
The second residual VQ level receives zero gradient, as mentioned here (https://github.com/lucidrains/vector-quantize-pytorch/issues/33). The relevant line in EnCodec is https://github.com/facebookresearch/encodec/blob/194329839fd812433992272fc5e7a889176e6fd1/encodec/quantization/core_vq.py#L336
The fix is in this commit: https://github.com/lucidrains/vector-quantize-pytorch/commit/ecf2f7c7c7701029488500b6b3d46b7c895063bf
@adefossez
Thanks for bringing that up!
It seems like this won't impact the Straight-Through-Estimator gradient for the Encoder, but will kill the commitment loss for all residual VQ levels except the first one, right?
It seems so. But I'm not sure how much it affects the final result.
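To make the commitment-loss point concrete, here is a minimal sketch rather than the EnCodec code itself. It assumes the commitment loss has the common form of an MSE between the detached code and the layer input; `nearest_code` and `second_level_commit_grad` are hypothetical helpers written for this illustration, and `x` stands in for the encoder output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def nearest_code(x, codebook):
    # x: (1, dim), codebook: (n_code, dim); return the closest row as (1, dim)
    idx = ((codebook - x) ** 2).sum(dim=1).argmin()
    return codebook[idx].unsqueeze(0)

def second_level_commit_grad(detach_q1):
    dim = 5
    x = torch.randn(1, dim, requires_grad=True)  # stand-in for the encoder output
    codebook1 = torch.randn(10, dim)
    codebook2 = torch.randn(10, dim)

    q1 = nearest_code(x, codebook1)
    q1 = x + (q1 - x).detach()  # straight-through estimator
    residual = x - (q1.detach() if detach_q1 else q1)

    q2 = nearest_code(residual, codebook2)
    # assumed form of the second level's commitment loss
    commit = F.mse_loss(q2.detach(), residual)
    commit.backward()
    return x.grad

grad_without_detach = second_level_commit_grad(detach_q1=False)
grad_with_detach = second_level_commit_grad(detach_q1=True)
print(grad_without_detach)  # the two paths through x cancel
print(grad_with_detach)
```

Without the detach, `residual = x - q1` has zero derivative with respect to `x`, so the second level's commitment loss cannot reach the encoder; with the detach it can.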
I'm a bit reluctant to introduce a change we haven't tested in this codebase, as it could change the best hyperparameters etc. I can, however, add a warning pointing to this issue when the model is used in training mode.
@adefossez @npuichigo Could you please explain in more detail why "this won't impact the Straight-Through-Estimator gradient for the Encoder"? I think that if the residual is computed in a way that doesn't pass its real gradients, then the gradient estimator may also be affected. The following code snippet may illustrate this:
import torch
def quantize(x, codebook):
    diff = codebook - x # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]
dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)
q1 = quantize(x, codebook1) # quantize x with first codebook
q1 = x + (q1 - x).detach() # transplant q1's gradient to x
residual = x - q1 # detach q1 or not may make a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2) # quantize residual with second codebook
q2 = residual + (q2 - residual).detach() # transplant q2's gradient to residual
loss = 0*q1.sum() + 1*q2.sum() # loss is a function of q1 and q2, now it is independent of q1.
loss.backward()
print(x.grad)
The printed gradient is all zero, but if we replace residual = x - q1 with residual = x - q1.detach(), the gradient will be non-zero.
Why did you put 0 * q1.sum()? That is what is breaking the STE gradient. With the current code, d q_1 / d x = Id and d q_i / d x = 0 for all i > 1, which is okay, as the overall gradient d (sum q_i) / d x = Id is what we want. The only thing that is impacted is the commitment loss.
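For two levels, the claim can be written out as follows. This is a sketch in assumed notation: sg(·) denotes stop-gradient (`.detach()`), c_1 and c_2 are the selected codebook entries, and r_1 is the residual passed to the second level.

```latex
\begin{aligned}
q_1 &= x + \operatorname{sg}(c_1 - x)     &\Rightarrow\quad \frac{\partial q_1}{\partial x} &= I \\
r_1 &= x - q_1                            &\Rightarrow\quad \frac{\partial r_1}{\partial x} &= I - I = 0 \\
q_2 &= r_1 + \operatorname{sg}(c_2 - r_1) &\Rightarrow\quad \frac{\partial q_2}{\partial x} &= \frac{\partial r_1}{\partial x} = 0 \\
\frac{\partial (q_1 + q_2)}{\partial x} &= I + 0 = I
\end{aligned}
```

The same cancellation repeats at every later level, so only q_1 carries a gradient back to x.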
Oh, I think I over-complicated the problem here. In the model, all the quantization outputs q_i are simply added to feed the decoder, so the relation d (sum q_i) / d x = Id keeps the STE working. In my code snippet, I assumed the loss could be an arbitrary function of q1 and q2. In that case, the gradient from q2 never reaches the preceding networks, which may not be good.
Still, if we replace residual = x - q1 with residual = x - q1.detach(), it seems d (sum q_i) / d x = n*Id then. Thus the scale of the losses may be affected. Thanks for the clarification!
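The scaling can be checked numerically. Below is a minimal sketch, not the EnCodec implementation: `nearest_code` and `grad_of_summed_output` are hypothetical helpers, the codebooks are random, and the loss is simply the sum of all level outputs, standing in for whatever the decoder would compute.

```python
import torch

torch.manual_seed(0)

def nearest_code(x, codebook):
    # x: (1, dim), codebook: (n_code, dim); return the closest row as (1, dim)
    idx = ((codebook - x) ** 2).sum(dim=1).argmin()
    return codebook[idx].unsqueeze(0)

def grad_of_summed_output(n_levels, detach_quantized):
    dim = 5
    x = torch.randn(1, dim, requires_grad=True)
    residual = x
    total = torch.zeros_like(x)
    for _ in range(n_levels):
        codebook = torch.randn(10, dim)
        q = nearest_code(residual, codebook)
        q = residual + (q - residual).detach()  # straight-through estimator
        total = total + q                       # outputs of all levels are summed
        residual = residual - (q.detach() if detach_quantized else q)
    total.sum().backward()
    return x.grad

print(grad_of_summed_output(3, detach_quantized=False))  # every entry is 1
print(grad_of_summed_output(3, detach_quantized=True))   # every entry is 3
```

With three levels, the undetached version gives a gradient of ones (d (sum q_i) / d x = Id), while the detached version gives threes (n*Id), which is the change in scale discussed above.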
@adefossez @cantabile-kwok
If residual = residual - quantized, then the second codebook can update with its gradient, but it cannot affect the first codebook. If residual = residual - quantized.detach(), then the second codebook's gradient will affect the first codebook.
In core_vq.py, there is the following code in VectorQuantization Class :
Now there is the following code in the ResidualVectorQuantization Class
So this problem is equivalent to the following one. The code snippet below may illustrate it:
```
import torch

def quantize(x, codebook):
    diff = codebook - x # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1) # quantize x with first codebook
q1 = x + (q1 - x).detach() # transplant q1's gradient to x
residual = x - q1.detach() # detach q1 or not may make a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2) # quantize residual with second codebook
q2 = residual + (q2 - residual).detach() # transplant q2's gradient to residual

loss = 1*q2.sum() # loss is a function of q1 and q2, now it is independent of q1.
loss.backward()
print(x.grad)
```
If residual = x - q1, then x.grad = 0; if residual = x - q1.detach(), then x.grad = tensor([[1., 1., 1., 1., 1.]]).