TensorRT

Performance Discrepancy between Quantized ONNX Model and FP16 Model

Open ThomasLesp opened this issue 2 years ago • 11 comments

Hello! I am working with a quantized ONNX model (using explicit quantization), and I've noticed an unexpected performance issue: the quantized model runs slower than the FP16 model, even though I benchmark both with trtexec and all the relevant precision flags enabled. Below are the links to the two models:

  1. Link to quantized ONNX model
  2. Link to full precision ONNX model

The model is distilbert-base-uncased fine-tuned on IMDB.

Inference times:

  • Full-precision model in FP16: 2.29 ms (mean)
  • Quantized model (INT8 + FP16): 2.97 ms

To benchmark the INT8 quantized model, I use `trtexec --onnx=model_quantized.onnx --int8 --fp16`. For the FP16 model, I use `trtexec --onnx=model_fp.onnx --fp16`.

I also tried adding `--precisionConstraints=obey`, but it made no difference.

My suspicion is that some optimizations, such as fused attention, are not being applied to the quantized model. However, I'm unsure how to enable them or how to verify that they are being applied. I would greatly appreciate any advice or guidance on this matter. Is there a way to apply these optimizations to the quantized model, or is there a workaround that could improve performance?

Thank you very much for your time and assistance.

Best Regards,

Environment

  • Docker image: nvcr.io/nvidia/tensorrt:23.06-py3
  • TensorRT version: 8.6.1
  • NVIDIA GPU: A100
  • NVIDIA driver version: 525.105.17
  • CUDA version: 12.1.1
  • cuDNN version: 8.9.2

Model:

  • Type: distilbert-base-uncased fine-tuned on IMDB
  • ONNX IR version: 0.0.9
  • Opset version: 16

Relevant Files

Full-precision model: https://drive.google.com/file/d/1rJfISBJUMzzARCqXx3LvZBqtrs2X67fm/view?usp=sharing
Quantized model: https://drive.google.com/file/d/1o-M0osj0GqmZAzWzogvX784os6RRdtUV/view?usp=sharing

Steps To Reproduce

Commands or scripts:

trtexec --onnx=model_quantized.onnx --int8 --fp16
trtexec --onnx=model_fp.onnx --fp16

ThomasLesp, Jul 20 '23

Hi, I have a very similar problem with a transformer model. In my case, the quantized model also takes more GPU memory. Can you check the device memory with `polygraphy inspect model <your_engine_file>`? I wonder whether it is bigger for you too. My guess is that the main goal of quantization is to reduce the memory footprint rather than to increase speed, so if your engine is smaller on the GPU, this behavior might be expected.

UlkuTuncerKucuktas, Jul 20 '23

Usually, this is caused by sub-optimal Q/DQ placement; please refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks. You can also compare the verbose logs and check the layer-wise precision and performance to find the cause.
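Following the suggestion to compare verbose logs and layer-wise precision, here is a minimal sketch that assembles trtexec command lines with per-layer reporting enabled for both builds. It only builds and prints the commands; it does not run trtexec. The flags used (`--verbose`, `--profilingVerbosity=detailed`, `--dumpLayerInfo`, `--exportLayerInfo`, `--dumpProfile`, `--separateProfileRun`, `--exportProfile`) are trtexec options in TensorRT 8.x, but check `trtexec --help` for your version. The helper `build_trtexec_cmd` and the output file names are hypothetical.

```python
import shlex


def build_trtexec_cmd(onnx_path, precisions, tag):
    """Hypothetical helper: build a trtexec command that benchmarks
    `onnx_path` and reports per-layer precision and timing."""
    cmd = ["trtexec", f"--onnx={onnx_path}"]
    # Enable the requested precisions, e.g. ("int8", "fp16") -> --int8 --fp16.
    cmd += [f"--{p}" for p in precisions]
    cmd += [
        "--verbose",                            # full builder log
        "--profilingVerbosity=detailed",        # keep per-layer details in the engine
        "--dumpLayerInfo",                      # print each engine layer to the console
        f"--exportLayerInfo={tag}_layers.json", # same info as JSON (file name assumed)
        "--dumpProfile",                        # per-layer timings
        "--separateProfileRun",                 # profile in a separate run from the latency benchmark
        f"--exportProfile={tag}_profile.json",  # per-layer timings as JSON (file name assumed)
    ]
    return cmd


quant = build_trtexec_cmd("model_quantized.onnx", ("int8", "fp16"), "int8")
fp16 = build_trtexec_cmd("model_fp.onnx", ("fp16",), "fp16")
for cmd in (quant, fp16):
    print(shlex.join(cmd))
```

Running the printed commands and diffing the two exported layer-info files is one way to spot layers that still run in high precision in the quantized engine.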

I'll let @nvpohanh comment first before deciding whether I should investigate further.

zerollzeng, Jul 23 '23

Could you try removing the Q/DQ ops before BiasAdd? Those are not needed and may break MatMul+bias fusion.
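The suggested fix, dropping the Q/DQ pair between a MatMul and its bias Add so the MatMul+bias fusion is not blocked, can be illustrated on a toy graph. This is a minimal sketch that uses a simplified, hypothetical `Node` representation rather than the real ONNX or onnx-graphsurgeon API; in practice the same rewiring would be done with a graph-editing tool. It bypasses `MatMul -> QuantizeLinear -> DequantizeLinear -> Add` chains only when the Q/DQ outputs have no other consumers.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical, simplified stand-in for an ONNX node.
    name: str
    op: str
    inputs: list
    outputs: list


def remove_qdq_before_bias_add(nodes):
    """Bypass MatMul -> QuantizeLinear -> DequantizeLinear -> Add chains."""
    producer = {out: n for n in nodes for out in n.outputs}
    uses = Counter(t for n in nodes for t in n.inputs)  # consumer count per tensor
    dead = set()
    for add in (n for n in nodes if n.op == "Add"):
        for i, tensor in enumerate(add.inputs):
            dq = producer.get(tensor)
            if dq is None or dq.op != "DequantizeLinear":
                continue
            q = producer.get(dq.inputs[0])
            if q is None or q.op != "QuantizeLinear":
                continue
            src = producer.get(q.inputs[0])
            if src is None or src.op != "MatMul":
                continue
            # Only safe if this Add is the sole consumer of the Q/DQ outputs.
            if uses[tensor] != 1 or uses[dq.inputs[0]] != 1:
                continue
            add.inputs[i] = q.inputs[0]  # feed the MatMul output straight into Add
            dead.update({q.name, dq.name})
    return [n for n in nodes if n.name not in dead]


graph = [
    Node("fc", "MatMul", ["x", "W"], ["fc_out"]),
    Node("q", "QuantizeLinear", ["fc_out", "s", "zp"], ["fc_q"]),
    Node("dq", "DequantizeLinear", ["fc_q", "s", "zp"], ["fc_dq"]),
    Node("bias", "Add", ["fc_dq", "b"], ["y"]),
]
cleaned = remove_qdq_before_bias_add(graph)
print([n.op for n in cleaned])  # ['MatMul', 'Add']
print(cleaned[-1].inputs)       # ['fc_out', 'b']
```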

If the performance is still worse than FP16 after removing those Q/DQs, please follow the instructions in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#report-performance-issue to provide Nsys profiles so that we can debug this. Thanks.

nvpohanh, Jul 24 '23

Thank you for your responses and suggestions.

I have made the requested changes to the model by removing the Q/DQ ops before BiasAdd, but unfortunately, the performance is still slower than the FP16 version.

The updated quantized model can be accessed at this URL: Link to the Quantized ONNX Model

I have performed Nsys profiling for both the full precision and quantized versions, and here are the reports:

Additionally, I tried quantizing the MatMul in the attention layer of another model, but it didn't significantly impact the performance. The results were quite similar to the previous quantized model. Do you know if there are any best practices for quantizing the attention layer that could potentially help improve performance? Does the quantization process hinder the ability to fuse the attention layers and utilize Flash Attention optimizations?

Here are the relevant files for your reference:

Models

Nsys Reports

Logs

Folder

All of these files are available here: https://drive.google.com/drive/folders/1ktu09bODugJFA1bbdWBEGA956zj41zK7?usp=sharing

Thank you once again for your assistance and insights. Best regards,

ThomasLesp, Jul 24 '23

[Three screenshots were attached here.]

Some findings:

  1. fc_qkv and fc_aout both run in high precision. To get a speedup, Q/DQ ops should be inserted before these MatMul ops.
  2. For fc_aout and fc_out, there should also be Q/DQ between the BiasAdd and the ResidualAdd so that these two GEMMs can output INT8 instead of high precision.
  3. The fc_mid+gelu fusion is broken because of the uneven SeqLen (437). Is it possible to pad the SeqLen to a multiple of 32, such as 448? Uneven shapes are bad for Tensor Cores. You can use the attention mask to mask out the padded part.
  4. If you would also like to quantize the MHA, insert Q/DQ before the two MatMuls (as in the "Model with Attention Quantized") and move the Div op to after the first MatMul in MHA. I think the uneven SeqLen may also hinder some MHA optimizations, so padding would help there as well.
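Finding 3 suggests padding the sequence length to a multiple of 32 and masking out the padded positions. A minimal plain-Python sketch, assuming token IDs arrive as a list and that pad ID 0 (the `[PAD]` ID in the BERT uncased vocabulary) is used; `round_up` and `pad_to_multiple` are hypothetical helpers. It shows that 437 rounds up to 448 and that the mask is 1 for real tokens and 0 for padding.

```python
def round_up(n, multiple=32):
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


def pad_to_multiple(token_ids, multiple=32, pad_id=0):
    """Pad token_ids to a multiple of `multiple`; return (ids, attention_mask)."""
    target = round_up(len(token_ids), multiple)
    n_pad = target - len(token_ids)
    ids = list(token_ids) + [pad_id] * n_pad
    # 1 = real token, 0 = padding that attention should ignore.
    mask = [1] * len(token_ids) + [0] * n_pad
    return ids, mask


print(round_up(437))  # 448
ids, mask = pad_to_multiple(list(range(1, 438)))  # 437 dummy token IDs
print(len(ids), sum(mask))  # 448 437
```

Because the mask excludes the padded positions from attention, the outputs for the real tokens should be unaffected, at the cost of some extra compute on the padded rows.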

@zerollzeng Could we file an internal tracker for this? I think the uneven SeqLen issue may be solved in the next TRT version, but we will need to verify that internally. Thanks.

nvpohanh, Jul 25 '23

Thank you for your answer!

I've implemented the changes you suggested and generated two models: one with the MHA quantized and one without.

The performance has improved with the following inference times:

  • Quantized model with MHA quantized: 1.62 ms (mean)
  • Quantized model without MHA quantized: 2.02 ms (mean)
  • Full-precision model in FP16: 1.62 ms (mean)
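To put the reported means side by side, here is a small sketch that computes each variant's latency relative to the FP16 baseline, using only the numbers posted in this thread; `relative_change` is a hypothetical helper.

```python
baseline_ms = 1.62  # full-precision model in FP16 (after padding)
results_ms = {
    "quantized, MHA quantized": 1.62,
    "quantized, MHA not quantized": 2.02,
}


def relative_change(latency_ms, baseline=baseline_ms):
    """Percent change vs. baseline; positive means slower."""
    return (latency_ms - baseline) / baseline * 100.0


for name, ms in results_ms.items():
    print(f"{name}: {ms:.2f} ms ({relative_change(ms):+.1f}% vs FP16)")

# First measurements in the thread: 2.97 ms (INT8 + FP16) vs 2.29 ms (FP16).
print(f"initial quantized model: {relative_change(2.97, 2.29):+.1f}% vs FP16")
```

So the padding and Q/DQ changes took the quantized model from roughly 30% slower than FP16 to parity, but not to a speedup.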

The adjustment in SeqLen also had a positive impact on the performance of the full precision model.

However, even with these optimizations, the quantized model still does not outperform the full-precision FP16 model. What could explain this? Do you see any other improvements that could be made? Which option do you think is best: quantizing the MHA and fc_qkv, or running the attention in FP16? And how can we confirm that the flash attention plugin is actually used?

I encountered a CUDA failure when attempting to export the Nsys report of the MHA-quantized model; the error message was `Cuda failure: unspecified launch failure`.

The new models are available at this link: https://drive.google.com/drive/folders/1ktu09bODugJFA1bbdWBEGA956zj41zK7?usp=sharing

Models

Nsys Reports

Thanks

ThomasLesp, Jul 25 '23

Hmm, I can't see any CUDA kernels in the MHA-quantized Nsys report. Maybe that is caused by the CUDA failure you saw?

Does this error only happen when you run trtexec with nsys?

Filed internal tracker number 4210352.

nvpohanh, Jul 26 '23

Yes, exactly. There is no error when I run trtexec on its own!

Have you already seen INT8 speedups over FP16 in transformer models?

ThomasLesp, Aug 01 '23

For BERT, we were able to see a ~20% performance difference between INT8 and FP16.

nvpohanh, Aug 29 '23

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions. Thanks, all!

ttyio, Oct 10 '23

Reopening since internal issue 4210352 is still open. Sorry...

ttyio, Oct 10 '23

This will be fixed in TRT 10.0. @zerollzeng, please help close this bug.

nv-samcheng, Mar 04 '24

Closed.

zerollzeng, Mar 04 '24