
Conversion from onnx hurts accuracy when model is using fp16

Open omera-nv opened this issue 2 years ago • 10 comments

Description

I have an onnx model (a t5 encoder that I exported from pytorch) that I wish to convert to trt. This works great, but when I try to convert the model to fp16, its accuracy drops and it produces nothing useful. I've also tried converting the onnx model to fp16 before converting to trt: the fp16 onnx model's accuracy is good, but the conversion to trt once again hurts it badly.

Environment

TensorRT Version: 8.5.3.1 NVIDIA GPU: NVIDIA RTX A4000 NVIDIA Driver Version: 515.43.04 CUDA Version: 11.7 CUDNN Version: 8.5.0 Operating System: Ubuntu 22.04 Python Version (if applicable): 3.10 Tensorflow Version (if applicable): N/A PyTorch Version (if applicable): 2.0.0 Baremetal or Container (if so, version): Baremetal

Relevant Files

The fp32 and fp16 onnx models can be downloaded from this link: https://drive.google.com/drive/folders/1zeAW2oPP-2VwnK-SKcRVVqed30BZLMKk?usp=sharing

Steps To Reproduce

This can be reproduced using polygraphy (after making tensorrt use np.bool_ instead of np.bool):

polygraphy run t5_fp32_encoder.onnx --onnxrt --trt
polygraphy run t5_fp32_encoder.onnx --onnxrt --trt --fp16
polygraphy run t5_fp16_encoder.onnx --onnxrt --trt --fp16

The first command converts the fp32 model and works; the second and third convert the fp32 model to TRT fp16 and the fp16 model to TRT fp16, respectively, and both fail.
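(For reference, the `np.bool` workaround mentioned above is typically a one-line monkey-patch applied before importing tensorrt or polygraphy. A sketch, assuming NumPy >= 1.24, where the deprecated alias was removed:)

```python
import numpy as np

# NumPy 1.24 removed the deprecated `np.bool` alias, which TensorRT 8.5's
# Python bindings still reference. Restoring the alias before importing
# tensorrt/polygraphy is a common stopgap when TensorRT itself cannot be
# upgraded.
if not hasattr(np, "bool"):
    np.bool = np.bool_
```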

EDIT: I just noticed that the entire output of the fp16 trt model is zeros, as can be seen from the following line in the polygraphy output:

...
[I]         trt-runner-N0-05/01/23-14:36:19: encoder_last_hidden_state | Stats: mean=2.2204e-16, std-dev=0, var=0, median=2.2204e-16, min=2.2204e-16 at (0, 0, 0), max=2.2204e-16 at (0, 0, 0), avg-magnitude=2.2204e-16
...

omera-nv avatar Apr 30 '23 21:04 omera-nv

It might be caused by LayerNorm overflow in FP16 (you should see a TRT warning when building the engine). You can try falling back the LayerNorm layers to FP32.
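One way to sketch that fallback without writing any builder code is through the Polygraphy CLI itself. Hedged: the `--precision-constraints` and `--layer-precisions` options exist in recent Polygraphy releases (check `polygraphy run -h` for your version), and the layer name below is a hypothetical placeholder, not one taken from this model:

```shell
# The layer name here is hypothetical -- list the real names first with:
#   polygraphy inspect model t5_fp32_encoder.onnx --display-as=trt
polygraphy run t5_fp32_encoder.onnx --onnxrt --trt --fp16 \
    --precision-constraints obey \
    --layer-precisions "encoder.block.0.layer.0.layer_norm:float32"
```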

zerollzeng avatar May 03 '23 14:05 zerollzeng

This is a continuation of https://github.com/NVIDIA/TensorRT/issues/2899 , in which I couldn't figure out how to stop an overflowing weight from being converted. I am getting some warnings but can't manage to prevent all of the problematic layers from being converted. Also - any reason why these layers would overflow in trt but not in onnx?

omera-nv avatar May 03 '23 19:05 omera-nv

any reason why these layers would overflow in trt but not in onnx?

FP16 has a much smaller dynamic range than FP32, and whether an overflow actually occurs depends on the internal implementation: onnxruntime doesn't do much perf optimization in this case, while TRT applies layer fusion to many transformer structures.
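To illustrate the range gap: FP16's largest finite value is 65504, while FP32 reaches about 3.4e38, so an intermediate value that is harmless in FP32 (e.g. the sum of squares inside a fused LayerNorm reduction) can saturate to infinity once a fused kernel runs it in FP16. A minimal demonstration:

```python
import numpy as np

# Largest finite FP16 value vs. a value that FP32 handles trivially.
print(np.finfo(np.float16).max)         # 65504.0
print(np.float16(np.float32(70000.0)))  # inf -- overflows in FP16
```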

I'll check it later.

zerollzeng avatar May 05 '23 07:05 zerollzeng

Could you please try TRT 8.6 GA? The results don't look too bad to me:

[I]             Absolute Difference | Stats: mean=0.0068461, std-dev=0.017569, var=0.00030868, median=0.0038256, min=2.5518e-09 at (5, 67, 193), max=0.40559 at (7, 3, 32), avg-magnitude=0.0068461
[I]                 ---- Histogram ----
                    Bin Range          |  Num Elems | Visualization
                    (2.55e-09, 0.0406) |     519560 | ########################################
                    (0.0406  , 0.0811) |       1718 |
                    (0.0811  , 0.122 ) |       1038 |
                    (0.122   , 0.162 ) |          4 |
                    (0.162   , 0.203 ) |        984 |
                    (0.203   , 0.243 ) |        492 |
                    (0.243   , 0.284 ) |          0 |
                    (0.284   , 0.324 ) |          0 |
                    (0.324   , 0.365 ) |          0 |
                    (0.365   , 0.406 ) |        492 |
[E] FAILED | Runtime: 39.566s | Command: /home/scratch.zeroz_sw/miniconda3/bin/polygraphy run t5_fp32_encoder.onnx --onnxrt --trt --fp16 --trt-opt-shapes input_ids:[8,128] attention_mask:[8,128] --input-shapes input_ids:[8,128] attention_mask:[8,128]

zerollzeng avatar May 07 '23 01:05 zerollzeng

[I]         Error Metrics: encoder_last_hidden_state
[I]             Minimum Required Tolerance: elemwise error | [abs=0.40578] OR [rel=1.8275e+15] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.006848, std-dev=0.017574, var=0.00030886, median=0.0038264, min=1.0194e-09 at (0, 127, 201), max=0.40578 at (7, 3, 32), avg-magnitude=0.006848
[I]                 ---- Histogram ----
                    Bin Range          |  Num Elems | Visualization
                    (1.02e-09, 0.0406) |     519564 | ########################################
                    (0.0406  , 0.0812) |       1714 |
                    (0.0812  , 0.122 ) |       1038 |
                    (0.122   , 0.162 ) |          4 |
                    (0.162   , 0.203 ) |        984 |
                    (0.203   , 0.243 ) |        492 |
                    (0.243   , 0.284 ) |          0 |
                    (0.284   , 0.325 ) |          0 |
                    (0.325   , 0.365 ) |          0 |
                    (0.365   , 0.406 ) |        492 |
[E] FAILED | Runtime: 40.150s | Command: /home/scratch.zeroz_sw/miniconda3/bin/polygraphy run t5_fp16_encoder.onnx --onnxrt --trt --fp16 --trt-opt-shapes input_ids:[8,128] attention_mask:[8,128] --input-shapes input_ids:[8,128] attention_mask:[8,128]

zerollzeng avatar May 07 '23 02:05 zerollzeng

@zerollzeng Is there any way I can use 8.6 without updating to CUDA 12?

omera-nv avatar May 07 '23 06:05 omera-nv

Yes, or you can just use our tensorrt docker image.

zerollzeng avatar May 08 '23 15:05 zerollzeng

OK, I ended up updating to CUDA 12 and TRT 8.6.1, and I'm still getting all-zero output for some reason.

omera-nv avatar May 11 '23 07:05 omera-nv

I was able to recover our transformer model from a similar situation by setting the reduce and unary layers to fp32:

# Pin the reduce/unary layers (the building blocks of LayerNorm) to FP32
for layer in network:
    if layer.type in (
        trt.LayerType.REDUCE,
        trt.LayerType.UNARY,
    ):
        layer.precision = trt.float32

The rest of the code can be found in this issue https://github.com/NVIDIA/TensorRT/issues/2899
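For context, here is a self-contained sketch of how a loop like that fits into an engine build. This is an assumption based on the TensorRT 8.x Python API, not code from the issue; in particular, per-layer `precision` requests are only honored when a precision-constraints flag is also set on the builder config:

```python
import tensorrt as trt

def build_fp16_with_fp32_norms(onnx_path: str) -> trt.IHostMemory:
    """Sketch: build an FP16 engine while pinning reduce/unary layers to FP32."""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError(parser.get_error(0))

    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    # Without a constraints flag, per-layer precision requests may be ignored.
    # PREFER lets TRT fall back if no FP32 kernel exists; OBEY fails the build.
    config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)

    for layer in network:
        if layer.type in (trt.LayerType.REDUCE, trt.LayerType.UNARY):
            layer.precision = trt.float32

    return builder.build_serialized_network(network, config)
```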

dchebakov avatar Jul 20 '23 20:07 dchebakov

I was able to recover our transformer model from a similar situation by setting the reduce and unary layers to fp32:

for layer in network:
    if layer.type in (
        trt.LayerType.REDUCE,
        trt.LayerType.UNARY,
    ):
        layer.precision = trt.float32

The rest of the code can be found in this issue #2899

Hi, the code in #2899 is a bit messy, with so many undetermined set_flags. Could you share a version that runs successfully?

niaoyu avatar Jan 12 '24 03:01 niaoyu