
Instance normalization with FP16 has large errors

Open ivan94fi opened this issue 3 years ago • 3 comments

I have a problem with instance normalization: the model outputs diverge substantially when running the TensorRT model with float16 precision.

Versions:

  • Python: 3.9
  • PyTorch: 1.10
  • CUDA: 11.5
  • TensorRT: 8.2.1.8
  • torch2trt: 0.3.0

Example code:

from torch import nn
import torch
import random
import numpy as np

from torch2trt import torch2trt

class Net(nn.Module):
    def __init__(self, nc):
        super().__init__()
        self.conv = nn.Conv2d(nc, nc, 3)
        self.in_norm = nn.InstanceNorm2d(nc, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.in_norm(x)

        return x

seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

nc = 32
input_data = torch.rand(1, nc, 256, 256)

with torch.no_grad():
    x = input_data.cuda()

    # model = nn.Sequential(*[Net(nc) for _ in range(32)])
    model = Net(nc)
    model = model.cuda()
    model.eval()

    out = model(x)

    # Build and evaluate the TensorRT engine with FP16 enabled.
    trt_model = torch2trt(model, [x], fp16_mode=True)
    trt_model.eval()
    trtout = trt_model(x)

    print(f"greatest difference trt (fp16): {(out - trtout).max().item()}")

    # Build and evaluate the TensorRT engine in the default FP32 mode.
    trt_model = torch2trt(model, [x])
    trt_model.eval()
    trtout = trt_model(x)

    print(f"greatest difference trt (fp32): {(out - trtout).max().item()}")

This is the output I get from this script:

greatest difference trt (fp16): 0.020292997360229492
greatest difference trt (fp32): 1.6689300537109375e-06

The error is too large in float16 mode, especially considering that this is a single instance normalization layer; my model contains many of them, so the errors propagate and grow.
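
For a rough baseline of what error FP16 can introduce per operation, compare the machine epsilon of the two formats (reductions such as the mean/variance inside instance norm can accumulate errors well beyond this bound):

import torch

# float16 carries roughly 3 significant decimal digits, float32 roughly 7.
print(torch.finfo(torch.float16).eps)  # 0.0009765625
print(torch.finfo(torch.float32).eps)  # 1.1920928955078125e-07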

I suspect this is because the instance normalization itself is executed in float16, which introduces the numerical errors.
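
One way to test this hypothesis outside of TensorRT is to run the same module in plain PyTorch half precision and compare it against the float32 reference (a quick check reusing model and x from the script above; if the half-precision PyTorch output shows a similar error, the divergence comes from FP16 arithmetic itself rather than from the torch2trt conversion):

# Quick FP16-vs-FP32 check in pure PyTorch, reusing `model` and `x` from the script above.
with torch.no_grad():
    out_fp32 = model(x)
    out_fp16 = model.half()(x.half()).float()
    print(f"greatest difference torch (fp16): {(out_fp32 - out_fp16).abs().max().item()}")
    model.float()  # restore FP32 weights before any further use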

ivan94fi avatar Jul 28 '22 14:07 ivan94fi

I'm having a similar issue; I don't know whether it is an FP16 overflow issue or an InstanceNorm issue.

deephog avatar Aug 02 '22 00:08 deephog

Hi,

I'm having a similar issue when trying to convert to FP16. The conversion worked fine with previous versions of PyTorch, and this started to happen when I upgraded torch from 1.8 to 1.10.

So one quick fix for you would be to downgrade to an earlier version of PyTorch; it might solve your problem. It is not a long-term solution though, because we will need to upgrade at some point! We still need to find the root of this problem anyway.

For the moment I'm reading the changelog to see if anything critical changed.

debloisg avatar Aug 25 '22 08:08 debloisg

Thanks for this information; however, downgrading PyTorch is not an option for us.

Any updates from the developers on this?

Thank you

ivan94fi avatar Aug 25 '22 08:08 ivan94fi