
INT8-quantized model performs worse than or similar to the non-quantized FP32 or FP16 model.

Raj-vivid opened this issue 1 year ago · 4 comments

I am using a pretrained ConvNeXtV2 model from timm. It contains LayerNorm and GlobalResponseNorm layers, but even after adding custom quant modules for LayerNorm, LayerNorm2d, and GlobalResponseNorm (GRN), I still can't make my INT8 model run faster than the base model with an FP16 engine. I am using the TensorRT Python bindings and ModelOpt (TensorRT Model Optimizer) to perform the quantization.

My code for creating the custom modules is as follows:

# Imports assumed for the snippets below (not shown in the original post):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from timm.layers import LayerNorm2d, GlobalResponseNorm

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import TensorQuantizer


class QuantLayerNorm(LayerNorm):
    def __init__(self, normalized_shape):
        super().__init__(normalized_shape)
        self._setup()

    def _setup(self):
        self.input_quantizer = TensorQuantizer()
        self.weight_quantizer = TensorQuantizer()

    def forward(self, input):
        input = self.input_quantizer(input)
        weight = self.weight_quantizer(self.weight)
        return F.layer_norm(input, self.normalized_shape, weight, self.bias, self.eps)
    

class QuantLayerNorm2d(LayerNorm2d):
    def __init__(self, normalized_shape):
        super().__init__(normalized_shape)
        self._setup()

    def _setup(self):
        self.input_quantizer = TensorQuantizer()
        self.weight_quantizer = TensorQuantizer()

    def forward(self, input):
        input = self.input_quantizer(input)
        weight = self.weight_quantizer(self.weight)
        input = input.permute(0, 2, 3, 1)
        input = F.layer_norm(input, self.normalized_shape, weight, self.bias, self.eps)
        input = input.permute(0, 3, 1, 2)
        return input
    

class QuantGlobalResponseNorm(GlobalResponseNorm):
    """Quantized Global Response Normalization layer with Tensor Quantizers."""
    
    def __init__(self, dim, eps=1e-6, channels_last=True):
        # The timm GlobalResponseNorm base class requires `dim`, so pass the
        # constructor arguments through to it.
        super().__init__(dim, eps=eps, channels_last=channels_last)
        self.eps = eps
        if channels_last:
            self.spatial_dim = (1, 2)
            self.channel_dim = -1
            self.wb_shape = (1, 1, 1, -1)
        else:
            self.spatial_dim = (2, 3)
            self.channel_dim = 1
            self.wb_shape = (1, -1, 1, 1)


        self.weight = nn.Parameter(torch.zeros(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        
        # Setup quantizers
        self._setup()
    
    def _setup(self):
        self.input_quantizer = TensorQuantizer()
        self.weight_quantizer = TensorQuantizer()
        
        # self.bias_quantizer = TensorQuantizer()
    
    def forward(self, x):
        x = self.input_quantizer(x)
        quant_weight = self.weight_quantizer(self.weight)

        # quant_bias = self.bias_quantizer(self.bias)
        quant_bias = self.bias  

        x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
        x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
        return x + torch.addcmul(
            quant_bias.view(self.wb_shape), 
            quant_weight.view(self.wb_shape), 
            x * x_n
        )
    

mtq.register(original_cls=LayerNorm, quantized_cls=QuantLayerNorm)
mtq.register(original_cls=LayerNorm2d, quantized_cls=QuantLayerNorm2d)
mtq.register(original_cls=GlobalResponseNorm, quantized_cls=QuantGlobalResponseNorm)

I am using the following config:

{'quant_cfg': {'*weight_quantizer': {'num_bits': 8, 'axis': 0},
  '*input_quantizer': {'num_bits': 8, 'axis': None},
  '*lm_head*': {'enable': False},
  '*block_sparse_moe.gate*': {'enable': False},
  '*router*': {'enable': False},
  '*output_layer*': {'enable': False},
  'output.*': {'enable': False},
  'nn.BatchNorm1d': {'*': {'enable': False}},
  'nn.BatchNorm2d': {'*': {'enable': False}},
  'nn.BatchNorm3d': {'*': {'enable': False}},
  'nn.LeakyReLU': {'*': {'enable': False}},
  'default': {'enable': False},
  '*output_quantizer': {'num_bits': 8, 'axis': None},
  'LayerNorm2d': {'*': {'enable': True}},
  'LayerNorm': {'*': {'enable': True}},
  'GlobalResponseNorm': {'*': {'enable': True}}},
 'algorithm': 'max'}

I am going by the documentation, but it is not clear to me whether I am doing something wrong. Help is much appreciated.
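
For reference, the config above is applied through ModelOpt's PTQ entry point roughly like this (a minimal sketch; calib_loader and the calibration loop are placeholders, not part of the original setup):

import torch
import modelopt.torch.quantization as mtq

# ModelOpt calls this with the model; running calibration batches through it
# lets the TensorQuantizers registered above collect activation ranges.
def forward_loop(model):
    model.eval()
    with torch.no_grad():
        for images, _ in calib_loader:  # placeholder calibration DataLoader
            model(images.cuda())

# `config` is the quant_cfg dict shown above, `model` the timm ConvNeXtV2 model.
model = mtq.quantize(model, config, forward_loop)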

Raj-vivid · Oct 02 '24

How is the prediction accuracy of your PTQ model?

lix19937 · Oct 05 '24

I tried this with ViT-B and ConvNeXt. I set up a ConvNeXt pipeline in this notebook: https://drive.google.com/file/d/1LTfJsAcTgJ3Rb8BXiAuC66OD-_9zEWSy/view?usp=drive_link

This is actually using PTQ. INT8 does not seem to give any performance boost over FP16, and in some cases it causes a slight slowdown.
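
For reference, the FP16 vs. INT8 timing comparison can be reproduced roughly as follows (a sketch; the ONNX file name is a placeholder and it assumes trtexec is on the PATH):

import subprocess

def run_trtexec(extra_flags):
    # Build and time an engine from the Q/DQ ONNX model; trtexec prints a
    # latency/throughput summary at the end of its output.
    cmd = ["trtexec", "--onnx=convnextv2_ptq.onnx", "--fp16"] + extra_flags
    out = subprocess.run(cmd, capture_output=True, text=True).stdout
    for line in out.splitlines():
        if "Latency" in line or "Throughput" in line:
            print(line)

run_trtexec([])          # FP16 baseline
run_trtexec(["--int8"])  # INT8 enabled on top of FP16 for the Q/DQ graph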

Rajjeshwar · Oct 07 '24

I tried a few more things. Exporting to ONNX with opset 17 allowed a fused LayerNorm node to be used, but even with Q/DQ nodes before the Conv operations, before the GELU activations, and before LayerNorm in my graph, I still get much slower speeds than FP16.

Before: [image]

After: [image]
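
The opset 17 export mentioned above looks roughly like this (a sketch; the file name, input shape, and axis names are illustrative):

import torch

model.eval()
dummy = torch.randn(1, 3, 224, 224, device="cuda")
torch.onnx.export(
    model,                    # the quantized timm model from the earlier snippet
    dummy,
    "convnextv2_ptq.onnx",
    opset_version=17,         # opset 17 allows the fused LayerNormalization node
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
)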

Raj-vivid · Oct 09 '24

Hello, I appreciate the help you give the community by answering all the issues in this tracker. Could you please have a look at this?

No valid tactics for finetuned_convnext.convnext.stages.0.blocks.0.mlp.fc1.weight + /finetuned_convnext/convnext/stages/stages.0/blocks/blocks.0/mlp/fc1/weight_quantizer/QuantizeLinear + /finetuned_convnext/convnext/stages/stages.0/blocks/blocks.0/mlp/fc1/Conv + PWN(PWN(PWN(PWN(PWN(PWN(PWN(/finetuned_convnext/convnext/stages/stages.0/blocks/blocks.0/mlp/act/Mul,

I keep getting this when using mtq.INT8_DEFAULT_CFG. Does it mean TensorRT does not support fusing LayerNorm with Conv and GELU activation? I still haven't been able to find the source of the slowdown.
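
One way to check which layers actually run in INT8 and which fall back to FP16/FP32 is TensorRT's engine inspector (a sketch; the engine path is a placeholder, and the engine must be built with detailed profiling verbosity for per-layer details to appear):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)
with open("convnextv2_int8.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

# Dump per-layer information (layer names, chosen tactics, precisions) as JSON.
inspector = engine.create_engine_inspector()
print(inspector.get_engine_information(trt.LayerInformationFormat.JSON))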

Raj-vivid · Oct 15 '24

Are you still facing this problem? For faster triage, are you able to provide your FP16 and INT8 models?

kevinch-nv · Feb 11 '25

Closing due to inactivity. Please feel free to reopen!

poweiw · May 29 '25