❓ [Question] Manually Annotate Quantization Parameters in FX Graph
❓ Question
is there a way to manually annotate quantization parameters that will be respected throughout torch_tensorrt conversion (e.g. manually adding q/dq nodes, or specifying some tensor metadata) via dynamo? thank you!
cc @narendasan @peri044 maybe? 🙏
This should be possible, as this is effectively what the TensorRT Model Optimizer toolkit does. @peri044 or @lanluo-nvidia could maybe give more specific guidance.
We currently use the NVIDIA Model Optimizer toolkit, which inserts quantization nodes into the torch model via its quantize API:
- https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/model_quant.py#L126
- https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/tensor_quant.py#L229-L243 (definition of custom ops which do the quantization). We have converters for these quantization custom ops (which call Q & DQ apis in TensorRT).
You can also manually insert a quantization custom op by implementing a lowering pass that adds these nodes to the torch.fx.GraphModule, and then implementing/registering a custom converter for it. You can attach custom metadata to that node by updating node.meta["val"] (see the sketch after the links below).
- https://docs.pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html (existing lowering passes)
- https://docs.pytorch.org/TensorRT/contributors/dynamo_converters.html (writing converters). This can be done outside the Torch-TRT codebase using the decorators described in these docs to register your lowering pass / converter.
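Roughly, a sketch of what that could look like (hedged: the decorator import paths and signatures below follow the two docs above but may differ across Torch-TensorRT versions; the "my_ns::fake_quant" op, the literal amax, and the converter body are purely illustrative):

import torch
from torch.fx import GraphModule

# Decorator import paths follow the linked docs; check your Torch-TensorRT version.
from torch_tensorrt.dynamo.lowering.passes import _aten_lowering_pass
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter


# Illustrative custom op: identity in eager mode, converted to real TensorRT
# Q/DQ by the custom converter registered below.
@torch.library.custom_op("my_ns::fake_quant", mutates_args=())
def fake_quant(x: torch.Tensor, amax: float) -> torch.Tensor:
    return x.clone()


@fake_quant.register_fake
def _(x: torch.Tensor, amax: float) -> torch.Tensor:
    return torch.empty_like(x)


@_aten_lowering_pass
def insert_weight_quant_nodes(gm: GraphModule, *args) -> GraphModule:
    # Example lowering pass: wrap every convolution weight in the custom quant op.
    for node in list(gm.graph.nodes):
        if node.target == torch.ops.aten.convolution.default:
            weight = node.args[1]
            with gm.graph.inserting_before(node):
                q = gm.graph.call_function(
                    torch.ops.my_ns.fake_quant.default,
                    args=(weight, 1.0),  # illustrative amax
                )
            q.meta["val"] = weight.meta.get("val")  # custom metadata lives on node.meta
            node.replace_input_with(weight, q)
    gm.graph.lint()
    gm.recompile()
    return gm


@dynamo_tensorrt_converter(torch.ops.my_ns.fake_quant.default)
def fake_quant_converter(ctx, target, args, kwargs, name):
    # Stub converter: this is where you would call TensorRT's Q/DQ APIs
    # (e.g. network.add_quantize / add_dequantize) and return the output tensor(s).
    raise NotImplementedError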
Please let me know if you have any further questions.
hey @peri044, thanks for the response. I tried the modelopt -> export flow on the simple model below. Am I using this wrong or missing something obvious? I'm using non-strict export (strict runs into torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)), but hitting ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'torch.Size'>. thanks!
import modelopt.torch.quantization as mtq
import torch
from modelopt.torch.quantization.utils import export_torch_mode


class JustAConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, inputs):
        return self.conv(inputs)


if __name__ == "__main__":
    model = JustAConv().to("cuda").eval()
    sample_input = torch.ones(1, 3, 224, 224).to("cuda")

    quant_cfg = mtq.INT8_DEFAULT_CFG
    mtq.quantize(
        model,
        quant_cfg,
        forward_loop=lambda model: model(sample_input),
    )

    with torch.no_grad():
        with export_torch_mode():
            exported_program = torch.export.export(model, (sample_input,), strict=False)
@patrick-botco I have tried your example with our latest main; with strict=False it works as expected. I guess your error might be related to your specific versions. Could you please let me know which versions you are using?
hey @lanluo-nvidia thanks for checking! here are my pytorch and modelopt versions:
nvidia-modelopt 0.29.0
nvidia-modelopt-core 0.29.0
torch 2.5.1
@patrick-botco I also remember that torch.export.export fails with strict=False at some point around torch 2.5. If you cannot use a higher torch version, the following workaround might help bypass the torch.export.export error:
from torch.export._trace import _export
exp_program = _export(model, (input_tensor,))
thanks @lanluo-nvidia - upgrading to torch 2.6 resolves the issue. compiling the exported program gives me something unexpected though.
for reference, the model (after mtq.quantize()) is:
JustAConv(
  (conv): QuantConv2d(
    3, 3, kernel_size=(3, 3), stride=(1, 1)
    (input_quantizer): TensorQuantizer(8 bit fake per-tensor amax=1.0000 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
    (weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=[0.1883, 0.1920](3) calibrator=MaxCalibrator quant)
  )
)
the issue: compiling the exported program
# continuing from above
trt_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=(sample_input,),
    enabled_precisions={torch.int8},
    min_block_size=1,
    debug=True,
)
the graph after the initial lowering passes looks good:
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %conv_input_quantizer__amax : [num_users=1] = get_attr[target=conv.input_quantizer._amax]
    %conv_weight_quantizer__amax : [num_users=1] = get_attr[target=conv.weight_quantizer._amax]
    %inputs : [num_users=1] = placeholder[target=inputs]
    %quantize_op : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%inputs, %conv_input_quantizer__amax, 8, 0, False, False), kwargs = {})
    %quantize_op_1 : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%conv_weight, %conv_weight_quantizer__amax, 8, 0, False, False), kwargs = {})
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%quantize_op, %quantize_op_1, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
however, after constant folding, %quantize_op_1 is optimized away and replaced by %_frozen_param0:
graph():
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %conv_input_quantizer__amax : [num_users=1] = get_attr[target=conv.input_quantizer._amax]
    %inputs : [num_users=1] = placeholder[target=inputs]
    %quantize_op : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%inputs, %conv_input_quantizer__amax, 8, 0, False, False), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%quantize_op, %_frozen_param0, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
per-channel weight quantization is not respected - it seems like %_frozen_param0 is float32 (_frozen_param0: (3, 3, 3, 3)@float32)
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv/convolution [aten.convolution.default] (Inputs: (quantize_op: (1, 3, 224, 224)@torch.float32, _frozen_param0: (3, 3, 3, 3)@float32, conv_bias: (3,)@float32, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) | Outputs: (convolution: (1, 3, 222, 222)@torch.float32))
more importantly, the gemm kernel itself is f32f32_f32f32_f32 (obtained through torch.profiler). The int8 layout conversions (cuInt8::nchwToNcqhw4 and cuInt8::ncqhw4ToNchw) make it look like we are only doing fake quantization:
| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
|---|---|---|---|---|---|---|---|---|---|---|
| sm80_xmma_fprop_implicit_gemm_f32f32_f32f32_f32_nchw... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 7.968us | 60.73% | 7.968us | 7.968us | 1 |
| cuInt8::nchwToNcqhw4(float const*, unsigned int*, in... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 2.784us | 21.22% | 2.784us | 2.784us | 1 |
| cuInt8::ncqhw4ToNchw(signed char const*, float*, int... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 2.368us | 18.05% | 2.368us | 2.368us | 1 |
| cudaStreamCreateWithPriority | 90.65% | 9.377ms | 90.65% | 9.377ms | 73.255us | 0.000us | 0.00% | 0.000us | 0.000us | 128 |
| cudaEventRecord | 0.07% | 7.380us | 0.07% | 7.380us | 3.690us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |
| cudaStreamWaitEvent | 0.08% | 8.602us | 0.08% | 8.602us | 4.301us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |
| cudaLaunchKernel | 0.49% | 50.595us | 8.81% | 911.265us | 455.632us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |
| Unrecognized | 8.32% | 860.670us | 8.32% | 860.670us | 215.168us | 0.000us | 0.00% | 0.000us | 0.000us | 4 |
| cuLaunchKernelEx | 0.07% | 7.510us | 0.07% | 7.510us | 7.510us | 0.000us | 0.00% | 0.000us | 0.000us | 1 |
| cudaDeviceSynchronize | 0.31% | 32.371us | 0.31% | 32.371us | 32.371us | 0.000us | 0.00% | 0.000us | 0.000us | 1 |
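(for reference, a torch.profiler capture along these lines produces a table like the one above; this is just a sketch, reusing trt_model and sample_input from the earlier snippet)

import torch
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    trt_model(sample_input)
    torch.cuda.synchronize()

# sort by device time so the conv kernel and the int8 layout conversions show up first
print(prof.key_averages().table(sort_by="self_cuda_time_total"))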
do you happen to know what the issue is? am I using this wrong / missing something? thanks! cc @peri044 @narendasan as well 🙏
i am using these versions to test:
torch 2.6.0
torch_tensorrt 2.6.0
nvidia-modelopt 0.29.0
nvidia-modelopt-core 0.29.0
@patrick-botco yes, the quantize op should be excluded from constant folding; the fix is already part of another PR (https://github.com/pytorch/TensorRT/blob/f989864d01321e20c9f7e536ad324aaffadc009b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py#L105). Let me first create a separate bug-fix PR for this, so that it can be merged to main ASAP.
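The gist of the change, as a rough sketch (illustrative only, not the actual diff in the linked constant_folding.py; the predicate name and op set here are made up):

import torch
from torch.fx import Node

# Ops that must survive to the TensorRT converter even when all of their inputs
# are constants (assumes torch.ops.tensorrt.quantize_op is registered).
QUANTIZE_OPS = {torch.ops.tensorrt.quantize_op.default}

def is_foldable(node: Node) -> bool:
    # Hypothetical predicate a constant-folding pass could consult: never fold a
    # quantize op, otherwise its result is frozen into a plain fp32 weight
    # (the %_frozen_param0 seen above) and the INT8 path is lost.
    return not (node.op == "call_function" and node.target in QUANTIZE_OPS)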
Here is the PR: https://github.com/pytorch/TensorRT/pull/3543. I have verified with your example that the kernel invoked is now sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32:
Name: conv.weight_quantizer/quantize_op_1 + [QUANTIZE]-[aten_ops.quantize_op.default]-[conv.weight_quantizer/quantize_op_1_quantize] + [CONVOLUTION]-[aten_ops.convolution.default]-[conv/convolution], LayerType: CaskConvolution, Inputs: [ { Name: (Unnamed Layer* 1) [Quantize]_output, Location: Device, Dimensions: [1,3,224,224], Format/Datatype: Int8 }, { Name: (Unnamed Layer* 7) [Constant]_output, Location: Device, Dimensions: [3], Format/Datatype: Float }], Outputs: [ { Name: output0, Location: Device, Dimensions: [1,3,222,222], Format/Datatype: Float }], ParameterType: Convolution, Kernel: [3,3], PaddingMode: kEXPLICIT_ROUND_DOWN, PrePadding: [0,0], PostPadding: [0,0], Stride: [1,1], Dilation: [1,1], OutMaps: 3, Groups: 1, Weights: {"Type": "Int8", "Count": 81}, Bias: {"Type": "Float", "Count": 0}, HasBias: 0, HasReLU: 0, HasSparseWeights: 0, HasDynamicFilter: 0, HasDynamicBias: 1, HasResidual: 0, ConvXAsActInputIdx: -1, BiasAsActInputIdx: -1, ResAsActInputIdx: -1, Activation: NONE, TacticName: sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_tilesize128x32x64_stage4_warpsize4x1x1_g1_tensor16x8x32_t1r3s3_alignc4, TacticValue: 0xa8b56a226b057463, StreamId: 0, Metadata:
thanks so much @lanluo-nvidia !