
[FixBug] Result Inconsistency By clip(opset < 11) with Integer Inputs

Open ganler opened this issue 3 years ago • 17 comments

Fix Conversion in Clip (opset < 11)

Consider clip(0, 1) in ONNX opset 10 (Clip-6):

[screenshot: the Clip-6 specification from the ONNX changelog]
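For reference, here is roughly how the two signatures differ when building nodes with onnx.helper (a minimal sketch; the tensor names are made up):

from onnx import helper

# Opset < 11: min/max are float *attributes* of the node.
clip6 = helper.make_node("Clip", inputs=["x"], outputs=["y"], min=0.0, max=1.0)

# Opset >= 11: min/max become optional tensor *inputs* instead,
# so they can match the data type of x.
clip11 = helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"])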

Symptom

Result inconsistency! It won't report any error but silently returns wrong results.

[screenshot: TensorRT silently returning wrong clip results]

Root Cause

In opset < 11, Clip's attributes ("min" and "max") are float-typed even if the input might be integer, float16, etc. See https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Clip-6

However, during conversion, onnx-trt simply assumes that the attribute types ("min" and "max") match the input tensor's type, so the floating-point protobuf value ends up being read as an integer, which is undefined behaviour.

Therefore, clip(min=0.0, max=1.0) will be regarded as clip(min=0, max=0) during conversion.
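The actual failure is in the C++ parser, but the protobuf behaviour can be illustrated from Python. A rough analogy (not the actual onnx-trt code path): the attribute value lives in the AttributeProto's float field, and reading a field of the wrong type silently yields the protobuf default of 0:

from onnx import helper

node = helper.make_node("Clip", inputs=["x"], outputs=["y"], min=0.0, max=1.0)
attr = node.attribute[0]   # attributes are sorted by name, so this is "max"
print(attr.name, attr.f)   # max 1.0 -- the value is stored in the float field
print(attr.name, attr.i)   # max 0   -- the int field was never set; protobuf returns 0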

To prove it, I put some logs around the attribute-parsing code:

https://github.com/onnx/onnx-tensorrt/blob/85e79f629fb546a75d61e3027fb259a9529144fe/builtin_op_importers.cpp#L425

[screenshots: debug logs of the parsed min/max attribute values]

As you can see, by casting it to float we get the correct value. The original implementation casts it to an integer, which yields an undefined value without raising any error.
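The real fix lives in C++ in builtin_op_importers.cpp, but the idea can be sketched in Python (hypothetical helper, not the parser's actual code): always read the opset < 11 attribute from the float field, then convert it to the input tensor's dtype explicitly.

import numpy as np
from onnx import helper

node = helper.make_node("Clip", inputs=["x"], outputs=["y"], min=0.0, max=1.0)
# Read min/max from the float field the spec mandates, then cast:
bounds = {a.name: a.f for a in node.attribute}
min_v, max_v = np.int32(bounds["min"]), np.int32(bounds["max"])
print(min_v, max_v)  # 0 1 -- the intended clip bounds, preserved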

After the change

[screenshot: correct results after the fix]

To reproduce the bug

Please use this model, generated by PyTorch with the script below.

import torch
import onnx
import tensorrt as trt        # needed when running the exported model with TensorRT
import pycuda.driver as cuda  # needed for device buffers during inference
import numpy as np
import pycuda.autoinit


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def forward(self, x):
        y = torch.clip(x, 0, 1)
        return y


model = Model()
# int32 input: accepted by torch.clip, but outside Clip-6's float-only type constraint
inputs = (torch.randn(10,).to(torch.int32),)
torch.onnx.export(model, inputs, "output.onnx", verbose=False,
                  input_names=["input"], opset_version=10, output_names=["output"])
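For reference, a correct backend should match np.clip. Below is a hedged sketch of checking that output.onnx at least parses with TensorRT's ONNX parser (API names per TensorRT 8; the full engine build and pycuda inference loop is omitted here):

import numpy as np
import tensorrt as trt

x = np.random.randint(-3, 4, size=10).astype(np.int32)
print("expected:", np.clip(x, 0, 1))  # what a correct backend should return

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("output.onnx", "rb") as f:
    print("parsed:", parser.parse(f.read()))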

ganler avatar Dec 14 '21 01:12 ganler

@kevinch-nv Hi Kevin, do you want to take a look? Thanks.

ganler avatar Dec 14 '21 02:12 ganler

@zerollzeng @kevinch-nv Could you help review this PR? :-)

ganler avatar Jan 06 '22 22:01 ganler

Hi @ganler, thanks for the PR. We don't have many PRs on this repo right now, so Kevin may not have noticed it. I can help create an internal PR for this if you want, or I can contact someone to review it, though that may take some time. Which do you prefer?

zerollzeng avatar Jan 07 '22 12:01 zerollzeng

Both work for me. Just want to help the community notice this bug. :-)

ganler avatar Jan 07 '22 15:01 ganler

Hi @ganler, I checked https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Clip-6

The input of the Clip operator with the INT32 datatype is invalid:

T : tensor(float16), tensor(float), tensor(double)
Constrain input and output types to float tensors.

If you run the model with ONNX Runtime, you will get an error like: failed: This is an invalid model. Type Error: Type 'tensor(int32)' of input parameter (input) of operator (Clip) in node (Clip_0) is invalid.
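(A minimal way to reproduce that ONNX Runtime error, assuming onnxruntime is installed:)

import onnxruntime as ort

# Raises "This is an invalid model. Type Error: ..." for the int32-input model:
sess = ort.InferenceSession("output.onnx", providers=["CPUExecutionProvider"])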

zerollzeng avatar Jan 08 '22 06:01 zerollzeng

But I think we can optimize here: since the ONNX specification restricts min and max to float, this may still cause accuracy issues with float16 or double inputs. I will check it further.

zerollzeng avatar Jan 08 '22 06:01 zerollzeng

@zerollzeng Yes, your observation is right. This will be helpful for float16 and double.

Another thing I found is that PyTorch will export those data types even though the ONNX spec does not support them (the ONNX checker passes them as well). So I still think we should improve it here. :-)

ganler avatar Jan 09 '22 04:01 ganler

Another thing I found is that PyTorch will export those data types even though the ONNX spec does not support them (the ONNX checker passes them as well). So I still think we should improve it here. :-)

I'm curious why the onnx checker doesn't report errors on this while onnxruntime throws the error and refuses to execute it. For the second question: TensorRT's ONNX parser will only follow the ONNX spec, and we won't enlarge the scope.

zerollzeng avatar Jan 09 '22 04:01 zerollzeng

I'm curious why the onnx checker doesn't report errors on this while onnxruntime throws the error and refuses to execute it.

IMHO, the ONNX checker only performs very relaxed checking, though it is adopted by many frameworks for "verification" purposes. Other inference engines are stricter about the spec, including the ORT example you mentioned. Another example is TVM for PReLU: https://github.com/pytorch/pytorch/issues/70570

For the second question: TensorRT's ONNX parser will only follow the ONNX spec, and we won't enlarge the scope.

This makes sense. I think we can forbid those invalid types; I will update the PR soon.

ganler avatar Jan 10 '22 20:01 ganler

CLA assistant check
All committers have signed the CLA.

CLAassistant avatar Jan 19 '22 20:01 CLAassistant

CLA signed. Sorry about the delay.

ganler avatar Jan 24 '22 18:01 ganler

@ganler if the ONNX checker is not failing for an invalid model, you can file an issue in github.com/onnx/onnx about that.

garymm avatar Feb 01 '22 18:02 garymm

@ganler if the ONNX checker is not failing for an invalid model, you can file an issue in github.com/onnx/onnx about that.

I see, thank you for the suggestion. But I have doubts, as such checks are widely ignored in my experience. For example, https://github.com/pytorch/pytorch/pull/72401

ganler avatar Feb 07 '22 03:02 ganler

Hi @zerollzeng @kevinch-nv, may I know if there should be any follow-ups regarding this PR? Thanks! :-)

ganler avatar Feb 07 '22 03:02 ganler

I suggest filing a PR on github.com/onnx/onnx so the ONNX checker can clarify this error, unless TRT has issues with FP16 or double precision inputs.

zerollzeng avatar Feb 07 '22 04:02 zerollzeng

I suggest filing a PR on github.com/onnx/onnx so the ONNX checker can clarify this error, unless TRT has issues with FP16 or double precision inputs.

@zerollzeng From https://github.com/onnx/onnx/issues/3995#issuecomment-1032048610, it seems that I was using an older ONNX tool where such errors were silenced. I tried the newest ONNX tool and it successfully reports the error. Although PyTorch did not use full_check=True at the time I reported this issue, it seems they enabled the full check two weeks ago: https://github.com/pytorch/pytorch/pull/71125.
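(For anyone hitting this later, a minimal sketch of the stricter check, assuming a recent onnx release:)

import onnx

model = onnx.load("output.onnx")
# full_check=True also runs shape/type inference, which rejects
# the int32 input to Clip-6:
onnx.checker.check_model(model, full_check=True)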

ganler avatar Feb 07 '22 23:02 ganler

@zerollzeng I did some digging. Clip won't affect double/fp16, since in those cases it will be substituted by min(max(...)). This might not be an issue if people only use the latest PyTorch (older PyTorch versions will trigger this bug). That said, the root cause is that older PyTorch creates models that violate the prior ONNX spec.
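(A quick way to see how a given PyTorch version lowers clip is to export with a floating-point input and inspect the node types; whether you get a single Clip node or a Min/Max pair depends on the exporter version:)

import torch
import onnx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.clip(x, 0, 1)

torch.onnx.export(M(), (torch.randn(10, dtype=torch.float64),),
                  "clip_f64.onnx", opset_version=10)
print([n.op_type for n in onnx.load("clip_f64.onnx").graph.node])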

ganler avatar Feb 13 '22 07:02 ganler