❓ [Question] Manually Annotate Quantization Parameters in FX Graph

Open patrick-botco opened this issue 6 months ago • 11 comments

❓ Question

Is there a way to manually annotate quantization parameters so that they are respected throughout torch_tensorrt conversion via dynamo (e.g. by manually adding Q/DQ nodes, or by specifying some tensor metadata)? Thank you!

patrick-botco, May 16 '25 07:05

cc @narendasan @peri044 maybe? 🙏

patrick-botco, May 16 '25 07:05

This should be possible, as this is effectively what the TensorRT Model Optimizer toolkit does. @peri044 or @lanluo-nvidia could maybe give more specific guidance.

narendasan, May 16 '25 14:05

We currently use the NVIDIA TensorRT Model Optimizer toolkit, which inserts quantization nodes into the torch model through its quantize API:

  1. https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/model_quant.py#L126
  2. https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/tensor_quant.py#L229-L243 (definitions of the custom ops that perform the quantization). We have converters for these custom quantization ops, which call the Q/DQ APIs in TensorRT.

You can also insert a quantization custom op manually. To do that, implement a lowering pass that adds these nodes to the torch.fx.GraphModule, then implement and register a custom converter for the op. You can attach custom metadata to the node by updating node.meta["val"].

  1. https://docs.pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html (existing lowering passes)
  2. https://docs.pytorch.org/TensorRT/contributors/dynamo_converters.html (writing converters). Both can be done outside the Torch-TRT codebase: use the decorators described in these pages to register your lowering pass / converter.
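Below is a minimal sketch of the lowering-pass idea in plain torch.fx. It assumes a symbolically traced module rather than the ATen-level graph that Torch-TensorRT's passes actually receive. `fake_quant_int8`, `insert_input_quant` and the `quant_params` meta key are hypothetical names. A real pass would insert a registered custom op (such as the modelopt quantize op) so that a converter can match it. Registering the pass and converter with Torch-TRT follows the docs linked above and is not shown here.

```python
import torch
import torch.fx as fx


def fake_quant_int8(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a quantize custom op: symmetric per-tensor INT8 fake quant."""
    scale = 127.0 / amax
    return torch.clamp(torch.round(x * scale), -128, 127) / scale


def insert_input_quant(gm: fx.GraphModule, amax: float) -> fx.GraphModule:
    """Route the input of every nn.Conv2d call through fake_quant_int8."""
    # Store amax as a buffer so the graph can reference it with a get_attr node.
    gm.register_buffer("input_amax", torch.tensor(amax))
    for node in list(gm.graph.nodes):
        if node.op != "call_module":
            continue
        if not isinstance(gm.get_submodule(node.target), torch.nn.Conv2d):
            continue
        conv_input = node.args[0]
        with gm.graph.inserting_before(node):
            amax_node = gm.graph.get_attr("input_amax")
            q_node = gm.graph.call_function(fake_quant_int8, (conv_input, amax_node))
            # Hypothetical metadata key that a later pass or converter could read.
            q_node.meta["quant_params"] = {"num_bits": 8, "amax": amax}
        # Only this conv node is rewired; other users of conv_input are untouched.
        node.replace_input_with(conv_input, q_node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

Applied to a traced `JustAConv`, the conv input now flows through the inserted node. The module's output then equals `conv(fake_quant_int8(x, amax))`.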

Please let me know if you have any further questions.

peri044, May 19 '25 19:05

Hey @peri044, thanks for the response. I tried modelopt -> export on the simple model below. Am I using this wrong or missing something obvious? I'm using non-strict export because strict export fails with `torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)`. Non-strict export, however, fails with `ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'torch.Size'>`. Thanks!

```python
import modelopt.torch.quantization as mtq
import torch
from modelopt.torch.quantization.utils import export_torch_mode


class JustAConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, inputs):
        return self.conv(inputs)


if __name__ == "__main__":
    model = JustAConv().to("cuda").eval()
    sample_input = torch.ones(1, 3, 224, 224).to("cuda")
    quant_cfg = mtq.INT8_DEFAULT_CFG
    mtq.quantize(
        model,
        quant_cfg,
        forward_loop=lambda model: model(sample_input),
    )

    with torch.no_grad():
        with export_torch_mode():
            exported_program = torch.export.export(model, (sample_input,), strict=False)
```

patrick-botco, May 26 '25 23:05

@patrick-botco I tried your example with our latest main, and with strict=False it works as expected. Your error might be specific to the versions you are using. Could you please let me know your versions?

lanluo-nvidia, May 29 '25 22:05

Hey @lanluo-nvidia, thanks for checking! Here are my PyTorch and modelopt versions:

```
nvidia-modelopt           0.29.0
nvidia-modelopt-core      0.29.0
torch                     2.5.1
```

patrick-botco, May 30 '25 05:05

@patrick-botco I also remember that torch.export.export failed with strict=False at some point around torch 2.5. If you cannot use a newer torch version, the following workaround might help you bypass the torch.export.export error:

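```python
# Private API: torch.export._trace._export may change between torch releases.
from torch.export._trace import _export

# `input_tensor` is your sample input (`sample_input` in the repro above)
exp_program = _export(model, (input_tensor,))
```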

lanluo-nvidia, May 30 '25 21:05

Thanks @lanluo-nvidia, upgrading to torch 2.6 resolves the issue. Compiling the exported program gives me something unexpected, though.

For reference, the model after `mtq.quantize()` is:

```
JustAConv(
  (conv): QuantConv2d(
    3, 3, kernel_size=(3, 3), stride=(1, 1)
    (input_quantizer): TensorQuantizer(8 bit fake per-tensor amax=1.0000 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
    (weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=[0.1883, 0.1920](3) calibrator=MaxCalibrator quant)
  )
)
```
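This stdlib-only sketch makes the `TensorQuantizer` printout above concrete. It shows symmetric INT8 fake quantization with a single per-tensor amax (like `input_quantizer`) and with one amax per output channel (`axis=0`, like `weight_quantizer`). The bound and scale formulas follow my reading of the pytorch-quantization/modelopt reference implementation: `scale = max_bound / amax`, with results clamped to [-128, 127] for signed 8-bit without narrow range. Treat them as an assumption, not modelopt's exact kernel. The function names are hypothetical.

```python
def fake_quant(values, amax, num_bits=8, unsigned=False, narrow_range=False):
    """Quantize-then-dequantize a flat list of floats using one amax (per-tensor)."""
    max_bound = 2 ** (num_bits - 1 + int(unsigned)) - 1  # 127 for signed 8-bit
    if unsigned:
        min_bound = 0
    elif narrow_range:
        min_bound = -max_bound  # -127
    else:
        min_bound = -max_bound - 1  # -128
    scale = max_bound / amax
    # Python's round(), like torch.round, rounds exact halves to the even integer.
    return [max(min_bound, min(max_bound, round(v * scale))) / scale for v in values]


def fake_quant_per_channel(channels, amaxes, **kwargs):
    """axis=0 per-channel variant: output channel i is quantized with amaxes[i]."""
    if len(channels) != len(amaxes):
        raise ValueError("expected one amax per output channel")
    return [fake_quant(ch, a, **kwargs) for ch, a in zip(channels, amaxes)]


# Two output channels whose (flattened) weights both contain 0.5.
weights = [[0.5], [0.5]]
print(fake_quant_per_channel(weights, [1.0, 0.5]))  # per-channel amax
print(fake_quant_per_channel(weights, [1.0, 1.0]))  # one shared amax, i.e. per-tensor
```

With per-channel amax, the same weight value lands on a different grid in each channel. This per-channel information is what gets lost if the weight's quantize op is folded into a plain float32 constant.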

The issue appears when compiling the exported program:

```python
            # continuing from above (requires `import torch_tensorrt` at the top of the script)
            trt_model = torch_tensorrt.dynamo.compile(
                exported_program,
                inputs=(sample_input,),
                enabled_precisions={torch.int8},
                min_block_size=1,
                debug=True,
            )
```

The graph after the initial lowering passes looks good:

```
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %conv_input_quantizer__amax : [num_users=1] = get_attr[target=conv.input_quantizer._amax]
    %conv_weight_quantizer__amax : [num_users=1] = get_attr[target=conv.weight_quantizer._amax]
    %inputs : [num_users=1] = placeholder[target=inputs]
    %quantize_op : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%inputs, %conv_input_quantizer__amax, 8, 0, False, False), kwargs = {})
    %quantize_op_1 : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%conv_weight, %conv_weight_quantizer__amax, 8, 0, False, False), kwargs = {})
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%quantize_op, %quantize_op_1, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
```

However, after constant folding, %quantize_op_1 (the weight quantizer) is optimized away and replaced by %_frozen_param0:

```
graph():
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %conv_input_quantizer__amax : [num_users=1] = get_attr[target=conv.input_quantizer._amax]
    %inputs : [num_users=1] = placeholder[target=inputs]
    %quantize_op : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%inputs, %conv_input_quantizer__amax, 8, 0, False, False), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%quantize_op, %_frozen_param0, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
```
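The difference between the two graphs above is whether the convolution's weight input still comes from a quantize op. A small torch.fx helper like the following could check that automatically. It is a hypothetical utility that matches node targets by substring of their string form (for an ATen op such as `torch.ops.tensorrt.quantize_op.default`, that string is the dotted op name). Relying on that naming is an assumption. You could run it on any graph you can access, such as `exported_program.graph` or a GraphModule inside a custom lowering pass.

```python
import torch.fx as fx


def quantize_nodes(graph: fx.Graph, needle: str = "quantize_op") -> list:
    """Names of call_function nodes whose target's string form contains `needle`."""
    return [
        n.name
        for n in graph.nodes
        if n.op == "call_function" and needle in str(n.target)
    ]


def conv_weight_is_quantized(graph: fx.Graph, needle: str = "quantize_op") -> dict:
    """Map each convolution node to whether its weight (args[1]) is produced by a quantize op."""
    result = {}
    for n in graph.nodes:
        if n.op == "call_function" and "convolution" in str(n.target):
            weight = n.args[1]
            result[n.name] = (
                isinstance(weight, fx.Node)
                and weight.op == "call_function"
                and needle in str(weight.target)
            )
    return result
```

On the first graph above this would report `True` for `convolution`. On the constant-folded graph the weight is a `get_attr` node, so the helper would report `False`.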

Per-channel weight quantization is not respected: %_frozen_param0 appears to be float32 (`_frozen_param0: (3, 3, 3, 3)@float32`):

```
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv/convolution [aten.convolution.default] (Inputs: (quantize_op: (1, 3, 224, 224)@torch.float32, _frozen_param0: (3, 3, 3, 3)@float32, conv_bias: (3,)@float32, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) | Outputs: (convolution: (1, 3, 222, 222)@torch.float32))
```

More importantly, the GEMM kernel itself is `f32f32_f32f32_f32` (obtained through torch.profiler). The profile also shows INT8 layout conversions (`cuInt8::nchwToNcqhw4` and `cuInt8::ncqhw4ToNchw`) around an FP32 kernel, which makes it look like we're only doing fake quantization:

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
sm80_xmma_fprop_implicit_gemm_f32f32_f32f32_f32_nchw...         0.00%       0.000us         0.00%       0.000us       0.000us       7.968us        60.73%       7.968us       7.968us             1  
cuInt8::nchwToNcqhw4(float const*, unsigned int*, in...         0.00%       0.000us         0.00%       0.000us       0.000us       2.784us        21.22%       2.784us       2.784us             1  
cuInt8::ncqhw4ToNchw(signed char const*, float*, int...         0.00%       0.000us         0.00%       0.000us       0.000us       2.368us        18.05%       2.368us       2.368us             1  
                           cudaStreamCreateWithPriority        90.65%       9.377ms        90.65%       9.377ms      73.255us       0.000us         0.00%       0.000us       0.000us           128  
                                        cudaEventRecord         0.07%       7.380us         0.07%       7.380us       3.690us       0.000us         0.00%       0.000us       0.000us             2  
                                    cudaStreamWaitEvent         0.08%       8.602us         0.08%       8.602us       4.301us       0.000us         0.00%       0.000us       0.000us             2  
                                       cudaLaunchKernel         0.49%      50.595us         8.81%     911.265us     455.632us       0.000us         0.00%       0.000us       0.000us             2  
                                           Unrecognized         8.32%     860.670us         8.32%     860.670us     215.168us       0.000us         0.00%       0.000us       0.000us             4  
                                       cuLaunchKernelEx         0.07%       7.510us         0.07%       7.510us       7.510us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize         0.31%      32.371us         0.31%      32.371us      32.371us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
```
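The profiler reading above hinges on the tactic names. As a rough heuristic, the first dtype tag after `implicit_gemm_` (skipping an optional `interleaved_`) seems to indicate the GEMM's input precision. This is an assumption based only on the two kernel names that appear in this thread, not a documented naming scheme. `gemm_input_dtype` is a hypothetical helper.

```python
def gemm_input_dtype(kernel_name: str) -> str:
    """Heuristic: classify an xmma implicit-GEMM kernel name by its first dtype tag."""
    marker = "implicit_gemm_"
    if marker not in kernel_name:
        return "unknown"  # e.g. layout-conversion kernels such as cuInt8::nchwToNcqhw4
    tail = kernel_name.split(marker, 1)[1]
    if tail.startswith("interleaved_"):
        tail = tail[len("interleaved_"):]
    first_tag = tail.split("_", 1)[0]  # e.g. "f32f32" or "i8f32"
    if first_tag.startswith("i8"):
        return "int8"
    if first_tag.startswith("f32"):
        return "fp32"
    return "unknown"


print(gemm_input_dtype("sm80_xmma_fprop_implicit_gemm_f32f32_f32f32_f32_nchw"))
print(gemm_input_dtype("sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32_nchw_vect_c_32"))
```

To scan a whole profile, the same function could be applied to `evt.key` for each entry in `prof.key_averages()`.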

Do you happen to know what the issue is? Am I using this wrong or missing something? Thanks! cc @peri044 @narendasan as well 🙏

I am testing with these versions:

```
torch                     2.6.0
torch_tensorrt            2.6.0
nvidia-modelopt           0.29.0
nvidia-modelopt-core      0.29.0
```

patrick-botco, Jun 01 '25 07:06

@patrick-botco Yes, the quantize op should be excluded from constant folding. The fix is already in another PR (https://github.com/pytorch/TensorRT/blob/f989864d01321e20c9f7e536ad324aaffadc009b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py#L105). Let me first create a separate bug-fix PR for you so that it can be merged to main ASAP.
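The gist of the fix is that constant folding must leave quantize ops alone. Otherwise, quantize(weight, amax) is evaluated ahead of time and replaced by a plain float tensor, as in the `_frozen_param0` graph above. Here is a simplified torch.fx sketch of that idea. It is a toy folder of my own, not the code in the linked PR. `fold_constants`, `_fetch_attr` and the `_folded_N` buffer names are hypothetical. In a Torch-TensorRT setting, the skip set would contain `torch.ops.tensorrt.quantize_op.default`.

```python
import torch
import torch.fx as fx


def _fetch_attr(gm: fx.GraphModule, target: str):
    """Resolve a dotted get_attr target such as 'conv.weight' on the module."""
    obj = gm
    for atom in target.split("."):
        obj = getattr(obj, atom)
    return obj


def fold_constants(gm: fx.GraphModule, skip_targets: set) -> fx.GraphModule:
    """Toy constant folder: pre-compute call_function nodes whose node inputs are
    all get_attr constants, but never fold a node whose target is in skip_targets."""

    def resolve(n: fx.Node):
        return _fetch_attr(gm, n.target)

    folded = 0
    for node in list(gm.graph.nodes):  # topological order, so chains fold in one pass
        if node.op != "call_function" or node.target in skip_targets:
            continue
        inputs = node.all_input_nodes
        if not inputs or any(n.op != "get_attr" for n in inputs):
            continue
        with torch.no_grad():
            value = node.target(
                *fx.node.map_arg(node.args, resolve),
                **fx.node.map_arg(node.kwargs, resolve),
            )
        if not isinstance(value, torch.Tensor):
            continue
        name = f"_folded_{folded}"
        folded += 1
        gm.register_buffer(name, value.detach())
        with gm.graph.inserting_before(node):
            const = gm.graph.get_attr(name)
        node.replace_all_uses_with(const)
        gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()  # drops get_attr nodes that are no longer used
    gm.graph.lint()
    gm.recompile()
    return gm
```

With the quantize op in `skip_targets`, arithmetic on the weight is still folded, but the quantize node survives with a constant input. A converter can then still see it and emit Q/DQ layers.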

lanluo-nvidia, Jun 01 '25 17:06

Here is the PR: https://github.com/pytorch/TensorRT/pull/3543

I have verified with your example that the kernel invoked is now sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32:

```
Name: conv.weight_quantizer/quantize_op_1 + [QUANTIZE]-[aten_ops.quantize_op.default]-[conv.weight_quantizer/quantize_op_1_quantize] + [CONVOLUTION]-[aten_ops.convolution.default]-[conv/convolution], LayerType: CaskConvolution, Inputs: [ { Name: (Unnamed Layer* 1) [Quantize]_output, Location: Device, Dimensions: [1,3,224,224], Format/Datatype: Int8 }, { Name: (Unnamed Layer* 7) [Constant]_output, Location: Device, Dimensions: [3], Format/Datatype: Float }], Outputs: [ { Name: output0, Location: Device, Dimensions: [1,3,222,222], Format/Datatype: Float }], ParameterType: Convolution, Kernel: [3,3], PaddingMode: kEXPLICIT_ROUND_DOWN, PrePadding: [0,0], PostPadding: [0,0], Stride: [1,1], Dilation: [1,1], OutMaps: 3, Groups: 1, Weights: {"Type": "Int8", "Count": 81}, Bias: {"Type": "Float", "Count": 0}, HasBias: 0, HasReLU: 0, HasSparseWeights: 0, HasDynamicFilter: 0, HasDynamicBias: 1, HasResidual: 0, ConvXAsActInputIdx: -1, BiasAsActInputIdx: -1, ResAsActInputIdx: -1, Activation: NONE, TacticName: sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_tilesize128x32x64_stage4_warpsize4x1x1_g1_tensor16x8x32_t1r3s3_alignc4, TacticValue: 0xa8b56a226b057463, StreamId: 0, Metadata:
```

lanluo-nvidia, Jun 01 '25 18:06

thanks so much @lanluo-nvidia !

patrick-botco, Jun 02 '25 00:06