
Inconsistent quantization format (PAMS, DAQ, CADyQ)

makhovds opened this issue on Feb 21, 2024 · 1 comment

Good day, and thanks for the ablation study presented in the paper. I wonder, though, why the quantization format differs between the PAMS, DAQ, and CADyQ repos. From Fig. 1 of the PAMS paper we expect the following quantization format in the ResBlock:

[image: ResBlock quantization diagram from the PAMS paper, Fig. 1]

From the PAMS repo we have the following code for the forward method:

def forward(self, x):
    residual = self.quant_act1(self.shortcut(x))   # quantize the skip branch
    body = self.body(x).mul(self.res_scale)
    res = self.quant_act3(body)                    # quantize the conv output
    res += residual

    return res

From DAQ repo:

def forward(self, x):
    if self.a_bit != 32:
        out = self.quant1(x)       # quantize the input to conv1
    else:
        out = x

    out = self.conv1(out)
    # if self.bn:
    #     out = self.BN1(out)

    out1 = self.act(out)

    if self.a_bit != 32:
        out1 = self.quant2(out1)   # quantize the input to conv2

    res = self.conv2(out1)
    # if self.bn:
    #     res = self.BN2(res)
    res = res.mul(self.res_scale)

    res += x                       # skip connection kept in FP
    return res

From CaDYQ repo:

def forward_ori(self, x):
    weighted_bits = x[4]
    f = x[3]
    bits = x[2]
    grad = x[0]
    x = x[1]

    x = self.shortcut(x)
    grad, x, bits, weighted_bits = self.bitsel1([grad, x, bits, weighted_bits])  # cadyq: quantize x
    residual = x                                   # quantized skip connection
    # grad, x, bits, weighted_bits = self.body[0]()  # cadyq
    # x = self.body[1:3](x)  # conv-relu
    x = self.body[0:2](x)  # conv-relu
    # grad, x, bits, weighted_bits = self.body[3]([grad, x, bits, weighted_bits])  # cadyq
    grad, x, bits, weighted_bits = self.body[2]([grad, x, bits, weighted_bits])  # cadyq: quantize the ReLU output
    # out = self.body[4](x)  # conv
    out = self.body[3](x)  # conv
    f1 = out
    out = out.mul(self.res_scale)
    out = self.quant_act3(out)                     # quantize the conv output
    out += residual                                # quant + quant add
    if self.loss_kdf:
        if f is None:
            f = f1.unsqueeze(0)
        else:
            f = torch.cat([f, f1.unsqueeze(0)], dim=0)
    else:
        f = None

    return [grad, out, bits, f, weighted_bits]

So I have the following questions:

  1. Why is a shortcut module added when it has no effect on the input?
  2. Am I right that the PAMS repo doesn't quantize the input to the first conv, even though Fig. 1 in their paper implies it, and that you therefore added this quantization in DAQ and CADyQ?
  3. Why do you quantize the output of the final conv? Again, Fig. 1 in the PAMS paper doesn't imply it, but it does seem more hardware-friendly: the residual add becomes a sum of two quantized values rather than FP plus quantized, where the quantized operand would first have to be cast to match the FP data type (see the small snippet after this list).
  4. In the DAQ paper you don't quantize the skip connection before the sum (you add the original input x), while in CADyQ you do (the residual variable is the quantized version of x). Is that correct? If so, what is the reason for the difference?
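
To make the hardware point in question 3 concrete, here is a tiny sketch; it is not code from any of the repos, and the scale value and tensors are arbitrary:

import torch

# Toy illustration for question 3: with a shared scale, the residual add of two
# quantized tensors can stay in the integer domain, while an FP skip forces the
# quantized operand to be dequantized (cast to FP) before the add.
scale = 0.05
res_int8 = torch.randint(-128, 128, (4,), dtype=torch.int8)
skip_int8 = torch.randint(-128, 128, (4,), dtype=torch.int8)
skip_fp32 = torch.randn(4)

# quant + quant: widen the accumulator, add in integers, dequantize once
sum_q = (res_int8.to(torch.int32) + skip_int8.to(torch.int32)).float() * scale

# quant + FP: the int8 operand must be promoted to FP before the add
sum_mixed = res_int8.float() * scale + skip_fp32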

Thank you very much in advance for your answers!

makhovds · Feb 21 '24 09:02

Hello, and sorry for such a delayed response. Below are the answers to each question.

  1. Thanks for pointing that out; yes, the shortcut module is not actually used (it has no effect).

  2. Yes, that is correct.

  3. Following PAMS, CADyQ also applies quantization to the convolution output. In DAQ, however, we considered quantizing the inputs and weights of the convolutions more critical for reducing computational complexity than quantizing the convolution output, so we did not quantize the convolution output in DAQ.

  4. Similar to the previous question, quantizing the skip connection would reduce the computational complexity a bit further, but we found the saving to be relatively insignificant. We also found that keeping the skip connection in FP gives better reconstruction accuracy. Considering both points, we kept the skip connections in FP in DAQ (a minimal sketch contrasting the two placements follows below).
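
To make answers 3 and 4 concrete, here is a minimal, self-contained sketch; it is not the actual repo code. fake_quant below is a toy symmetric uniform quantizer standing in for the learned quantizers of the three methods, and bit selection, knowledge distillation, and the shortcut module are omitted:

import torch
import torch.nn as nn
import torch.nn.functional as F


def fake_quant(x, bits=8):
    # Toy symmetric uniform fake-quantizer; a stand-in for the learned
    # quantizers, not the implementation used in PAMS/DAQ/CADyQ.
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    return torch.round(x / scale).clamp(-qmax, qmax) * scale


class ToyResBlock(nn.Module):
    # Minimal EDSR-style residual block, used only to contrast where the
    # activation quantizers sit in the two schemes discussed above.
    def __init__(self, channels=16, res_scale=1.0):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.res_scale = res_scale

    def forward_daq_style(self, x):
        # DAQ placement: quantize only the inputs of the two convs;
        # the conv2 output and the skip connection stay in FP.
        out = self.conv1(fake_quant(x))
        out = self.conv2(fake_quant(F.relu(out)))
        return out.mul(self.res_scale) + x              # FP + FP add

    def forward_cadyq_style(self, x):
        # CADyQ-style placement (bit selection omitted): quantize x once and
        # feed it to both the skip and conv1, quantize the ReLU output before
        # conv2, and quantize the scaled conv2 output before the residual add.
        xq = fake_quant(x)
        residual = xq                                   # quantized skip
        out = self.conv2(fake_quant(F.relu(self.conv1(xq))))
        out = fake_quant(out.mul(self.res_scale))       # quantized conv output
        return out + residual                           # quant + quant add


block = ToyResBlock()
x = torch.randn(1, 16, 8, 8)
print(block.forward_daq_style(x).shape, block.forward_cadyq_style(x).shape)

The convolution inputs are quantized in both variants; the only structural difference is at the residual add, which is exactly what answers 3 and 4 are about.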

Thanks for your interest and helpful feedback!

Cheeun · Mar 26 '24 10:03