
Zero grad in residual vq

Open npuichigo opened this issue 3 years ago • 8 comments

🐛 Bug Report

Zero grad in the second residual VQ, as mentioned here (https://github.com/lucidrains/vector-quantize-pytorch/issues/33): https://github.com/facebookresearch/encodec/blob/194329839fd812433992272fc5e7a889176e6fd1/encodec/quantization/core_vq.py#L336

The fix link is https://github.com/lucidrains/vector-quantize-pytorch/commit/ecf2f7c7c7701029488500b6b3d46b7c895063bf
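For context, a minimal sketch (not the actual EnCodec code; the nearest-neighbour lookup and shapes are simplified) of a two-level residual VQ with the straight-through estimator applied inside each quantizer, reproducing the zero gradient on the second level's commitment loss:

import torch

def ste_quantize(x, codebook):
    # nearest-neighbour lookup, then the straight-through estimator
    idx = torch.cdist(x, codebook).argmin(dim=-1)
    q = codebook[idx]
    return x + (q - x).detach(), q

x = torch.randn(1, 4, requires_grad=True)  # stand-in for the encoder output
cb1, cb2 = torch.randn(8, 4), torch.randn(8, 4)

q1_ste, _ = ste_quantize(x, cb1)
residual = x - q1_ste                      # the pattern at core_vq.py#L336 (no detach)
_, q2 = ste_quantize(residual, cb2)

commit2 = (residual - q2.detach()).pow(2).mean()  # commitment loss of the second quantizer
print(torch.autograd.grad(commit2, x)[0])         # all zeros: nothing reaches the encoder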

npuichigo avatar Dec 26 '22 11:12 npuichigo

@adefossez

npuichigo avatar Jan 06 '23 06:01 npuichigo

Thanks for pointing that out!

It seems like this won't impact the Straight-Through Estimator gradient for the Encoder, but it will kill the commitment loss for every residual VQ but the first one, right?

adefossez avatar Jan 11 '23 14:01 adefossez

It seems so. But I'm not sure how much it affects the final result.

npuichigo avatar Jan 11 '23 14:01 npuichigo

I'm a bit reluctant to introduce a change we haven't tested in this codebase, as it could change the best hyperparameters etc. However, I can add a warning pointing to this issue if the model is used in training mode.

adefossez avatar Jan 24 '23 12:01 adefossez

@adefossez @npuichigo Could you please explain in more detail why "this won't impact the Straight-Through-Estimator gradient for the Encoder"? I think that if the residual is computed in a way that doesn't pass its real gradients, the gradient estimator may also be affected. The following code snippet illustrates this:

import torch
def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)  # quantize x with first codebook
q1 = x + (q1 - x).detach()  # transplant q1's gradient to x
residual = x - q1  # detach q1 or not may make a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2)  # quantize residual with second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 0*q1.sum() + 1*q2.sum()  # loss is a function of q1 and q2, now it is independent of q1.
loss.backward()
print(x.grad)

The printed gradient is all zero, but if we replace residual = x - q1 with residual = x - q1.detach(), the gradient will be non-zero.

cantabile-kwok avatar Apr 04 '23 13:04 cantabile-kwok

Why did you put 0 * q1.sum()? That is what is breaking the STE gradient. With the current code, d q1 / d x = Id and d q_i / d x = 0 for all i > 1, which is okay because the overall gradient d (sum q_i) / d x = Id, which is what we want. The only thing that is impacted is the commitment loss.
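A quick numerical check of this claim, reusing the toy quantize helper from the snippet above (an illustration only, not the EnCodec implementation):

import torch

def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)
q1 = x + (q1 - x).detach()
residual = x - q1                        # current behaviour: no detach
q2 = quantize(residual, codebook2)
q2 = residual + (q2 - residual).detach()

(q1 + q2).sum().backward()               # the decoder consumes the sum of all levels
print(x.grad)                            # tensor([[1., 1., 1., 1., 1.]]) -> d (sum q_i) / d x = Id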

adefossez avatar Apr 04 '23 13:04 adefossez

Oh, I think I over-complicated the problem here. In the model, all the quantization outputs q_i are simply summed and fed to the decoder, so the relation d (sum q_i) / d x = Id keeps the STE working. In my code snippet I assumed the loss could be an arbitrary function of q1 and q2; in that case the gradient from q2 never reaches the preceding networks, which may not be good.

Still, if we replace residual = x - q1 with residual = x - q1.detach(), it seems that d (sum q_i) / d x = n*Id, so the scale of the losses may be affected. Thanks for the clarification!
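The same toy check with the detached residual shows the n*Id behaviour for n = 2 levels:

import torch

def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)
q1 = x + (q1 - x).detach()
residual = x - q1.detach()               # detached: the residual keeps its own gradient path to x
q2 = quantize(residual, codebook2)
q2 = residual + (q2 - residual).detach()

(q1 + q2).sum().backward()
print(x.grad)                            # tensor([[2., 2., 2., 2., 2.]]) -> d (sum q_i) / d x = 2 * Id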

cantabile-kwok avatar Apr 04 '23 14:04 cantabile-kwok

@adefossez @cantabile-kwok

If residual = residual - quantized, then the second codebook can update with a gradient but it cannot affect the first codebook. If residual = residual - quantized.detach(), then the second codebook's gradient will affect the first codebook.

In core_vq.py, the VectorQuantization class contains the following code: [image]

The ResidualVectorQuantization class then contains the following code: [image]

So this problem is equivalent to the following one, which the code snippet below illustrates:

import torch
def quantize(x, codebook):
    diff = codebook - x  # (n_code, dim)
    mse = (diff**2).sum(1)
    idx = torch.argmin(mse)
    return codebook[idx]

dim = 5
x = torch.randn(1, dim, requires_grad=True)
codebook1 = torch.randn(10, dim)
codebook2 = torch.randn(10, dim)

q1 = quantize(x, codebook1)  # quantize x with first codebook
q1 = x + (q1 - x).detach()  # transplant q1's gradient to x
residual = x - q1.detach()  # detach q1 or not may make a difference. Compute residual for next level quantizing
q2 = quantize(residual, codebook2)  # quantize residual with second codebook
q2 = residual + (q2 - residual).detach()  # transplant q2's gradient to residual

loss = 1*q2.sum()  # loss is now independent of q1
loss.backward()
print(x.grad)

If residual = x - q1, x.grad is all zeros; if residual = x - q1.detach(), x.grad = tensor([[1., 1., 1., 1., 1.]]).

DingWeiPeng avatar Jan 09 '24 11:01 DingWeiPeng