DAQ-pytorch
Inconsistent quantization format (PAMS, DAQ, CADyQ)
Good day, and thanks for the ablation study presented in the paper. I am wondering, though, why the quantization format differs between the PAMS, DAQ, and CADyQ repos.
From Fig. 1 of the PAMS paper, we would expect the following quantization format in the ResBlock.
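Roughly, here is my reading of the figure as a sketch (the module and quantizer names below are hypothetical, not taken from any of the repos; `quant` is a factory for whatever activation quantizer is used, e.g. `nn.Identity` as a stand-in):

```python
import torch.nn as nn

class ExpectedResBlock(nn.Module):
    """My reading of PAMS Fig. 1: quantize each conv input, no output quantization."""
    def __init__(self, n_feats, quant, res_scale=1.0):
        super().__init__()
        self.quant1 = quant()  # before conv1
        self.conv1 = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.act = nn.ReLU(inplace=True)
        self.quant2 = quant()  # before conv2
        self.conv2 = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.res_scale = res_scale

    def forward(self, x):
        out = self.act(self.conv1(self.quant1(x)))
        out = self.conv2(self.quant2(out))
        return out.mul(self.res_scale) + x  # sum with the identity
```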
From the PAMS repo, we have the following code for the forward method:
```python
def forward(self, x):
    # quantize the (identity) shortcut branch
    residual = self.quant_act1(self.shortcut(x))
    body = self.body(x).mul(self.res_scale)
    # quantize the output of the residual body
    res = self.quant_act3(body)
    res += residual
    return res
```
From the DAQ repo:
```python
def forward(self, x):
    if self.a_bit != 32:
        out = self.quant1(x)      # quantize the input to conv1
    else:
        out = x
    out = self.conv1(out)
    # if self.bn:
    #     out = self.BN1(out)
    out1 = self.act(out)
    if self.a_bit != 32:
        out1 = self.quant2(out1)  # quantize the input to conv2
    res = self.conv2(out1)
    # if self.bn:
    #     res = self.BN2(res)
    res = res.mul(self.res_scale)
    res += x                      # FP skip connection: the original x, not a quantized copy
    return res
```
From the CADyQ repo:
```python
def forward_ori(self, x):
    # unpack the list carried through the blocks
    weighted_bits = x[4]
    f = x[3]
    bits = x[2]
    grad = x[0]
    x = x[1]
    x = self.shortcut(x)
    grad, x, bits, weighted_bits = self.bitsel1([grad, x, bits, weighted_bits])  # cadyq
    residual = x                 # residual here is the quantized x
    # grad, x, bits, weighted_bits = self.body[0]()  # cadyq
    # x = self.body[1:3](x)  # conv-relu
    x = self.body[0:2](x)  # conv-relu
    # grad, x, bits, weighted_bits = self.body[3]([grad, x, bits, weighted_bits])  # cadyq
    grad, x, bits, weighted_bits = self.body[2]([grad, x, bits, weighted_bits])  # cadyq
    # out = self.body[4](x)  # conv
    out = self.body[3](x)  # conv
    f1 = out
    out = out.mul(self.res_scale)
    out = self.quant_act3(out)   # quantize the conv output
    out += residual
    if self.loss_kdf:
        if f is None:
            f = f1.unsqueeze(0)
        else:
            f = torch.cat([f, f1.unsqueeze(0)], dim=0)
    else:
        f = None
    return [grad, out, bits, f, weighted_bits]
```
So I have the following questions:
- Why is there a shortcut module at all, given that it has no effect on the input?
- Am I right that the PAMS repo doesn't quantize the input to the first conv, even though Fig. 1 in their paper implies this quantization, and that you therefore added it in DAQ and CADyQ?
- Why do you quantize the final conv output? Again, Fig. 1 in the PAMS paper doesn't imply it, but it does seem more hardware-friendly (a sum of two quantized values instead of an FP value and a quantized one, since the latter needs the quantized operand cast to the FP data type; see the sketch after this list).
- In the DAQ paper you don't quantize the skip connection before the sum (you add the original input x), while in CADyQ you do (the residual variable is the quantized version of x). Is that correct? If so, what is the reason for this difference?
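To illustrate the hardware-friendliness point in the third question, here is a toy sketch (the tensors, shapes, and scale are hypothetical, not from any of the repos):

```python
import torch

# two already-quantized tensors sharing a scale can be summed in integer arithmetic
residual_q = torch.randint(-128, 128, (1, 64, 8, 8), dtype=torch.int8)
body_q = torch.randint(-128, 128, (1, 64, 8, 8), dtype=torch.int8)
out_q = residual_q.to(torch.int16) + body_q.to(torch.int16)  # widen to avoid overflow

# an FP residual plus a quantized body forces a dequantize/cast before the add
residual_fp = torch.randn(1, 64, 8, 8)
scale = 0.05  # hypothetical quantization scale
out_fp = residual_fp + body_q.float() * scale
```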
Thank you very much in advance for your answers!
Hello, sorry for the delayed response. Below are the answers to each question.
- Thanks for pointing it out. Yes, the shortcut module is not used.
- Yes, that is true.
- Following PAMS, CADyQ also applies quantization to the convolution output. In DAQ, however, we considered quantizing the input and weight of the convolution to be more critical for reducing computational complexity than quantizing the convolution output, so we did not quantize the convolution output in DAQ.
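As a minimal sketch of this placement (the class and the fake-quantizer below are illustrative stand-ins, not DAQ's actual quantizer, which is described in the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantConv2d(nn.Conv2d):
    """Illustrative conv whose input and weight are fake-quantized; output stays FP."""
    def __init__(self, *args, a_bit=8, w_bit=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.a_bit, self.w_bit = a_bit, w_bit

    @staticmethod
    def _fake_quant(t, bits):
        # symmetric uniform fake quantization (placeholder for the real quantizer)
        qmax = 2 ** (bits - 1) - 1
        scale = t.abs().max().clamp(min=1e-8) / qmax
        return (t / scale).round().clamp(-qmax, qmax) * scale

    def forward(self, x):
        x_q = self._fake_quant(x, self.a_bit)            # low-bit input
        w_q = self._fake_quant(self.weight, self.w_bit)  # low-bit weight
        # the MAC-heavy convolution now sees low-bit operands; its output is left in FP
        return F.conv2d(x_q, w_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```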
- Similar to the previous question: quantizing the skip connection can further reduce computational complexity, but we found the reduction to be relatively insignificant. We also found that keeping the skip connection in FP is better for reconstruction accuracy. Considering both, we kept the skip connections in FP in DAQ.
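For intuition, a back-of-the-envelope count with hypothetical tensor sizes shows why the skip-connection add is cheap relative to the convolutions:

```python
# hypothetical feature-map size and channel count for one ResBlock
H = W = 64
C = 64
k = 3
conv_macs = 2 * H * W * C * C * k * k  # two k x k convolutions
skip_adds = H * W * C                  # one elementwise add on the skip
print(conv_macs // skip_adds)          # 1152: the convs dominate by ~3 orders of magnitude
```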
Thanks for your interest and helpful feedback!