
[Bug] ELU produced inconsistent inference results with PyTorch

Open Thrsu opened this issue 2 years ago • 2 comments

Actual behavior

Traceback (most recent call last):
 ...
    np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)
  File "/workplace/software/miniconda3/envs/torch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1528, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/workplace/software/miniconda3/envs/torch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 840, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-05

Mismatched elements: 1 / 102600 (0.000975%)
Max absolute difference: 2.028241e+31
Max relative difference: 1.2892904e-05
 x: array([[[[[2.116557e-01, 6.871684e-01, 1.114575e+00, ...,
           5.243167e-01, 1.032357e+00, 9.073991e-01],
          [3.627845e+37, 1.417131e+00, 6.541348e+37, ...,...
 y: array([[[[[2.116557e-01, 6.871684e-01, 1.114575e+00, ...,
           5.243167e-01, 1.032357e+00, 9.073991e-01],
          [3.627845e+37, 1.417131e+00, 6.541348e+37, ...,...
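The failure message comes from `np.testing.assert_allclose`, which checks |actual - desired| <= atol + rtol * |desired| elementwise. In the call above, actual is the PyTorch output and desired is the TVM output. The stdlib-only sketch below restates that criterion with a hypothetical helper name, `within_tolerance`. It shows that for outputs around 1e37 the atol term contributes nothing, so the check is effectively purely relative, and a relative difference of about 1.29e-5 fails rtol=1e-5:

```python
def within_tolerance(actual: float, desired: float,
                     rtol: float = 1e-5, atol: float = 1e-5) -> bool:
    """Elementwise criterion documented for numpy.testing.assert_allclose."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


desired = 3.627845e37  # magnitude of values visible in the traceback

# At this scale atol=1e-5 is negligible; only rtol * |desired| matters.
print(within_tolerance(desired * (1 + 1.2892904e-05), desired))  # False
print(within_tolerance(desired * (1 + 1e-6), desired))           # True
```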

Environment

  • TVM version: 0.13.dev0
  • PyTorch: 1.4.0

Steps to reproduce

import torch
from tvm import relay
import tvm
import numpy as np

# ELU with an extremely large negative alpha
m = torch.nn.ELU(alpha=-1.8574e+38)
input_data = torch.randn([9, 12, 19, 5, 10], dtype=torch.float32)
torch_outputs = m(input_data).numpy()

# Trace the module and import it into Relay
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([9, 12, 19, 5, 10]))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
input_data_np = input_data.numpy()

# Build for CPU (llvm) at opt_level=3 and run with the graph executor
with tvm.transform.PassContext(opt_level=3):
    exe = relay.create_executor('graph', mod=mod, params=params, device=tvm.device('llvm', 0), target='llvm').evaluate()
input_tvm = {'input0': tvm.nd.array(input_data_np.astype(np.float32))}
tvm_outputs = exe(**input_tvm).numpy()

# Compare TVM against PyTorch
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)

Triage

  • needs-triage
  • frontend:torch

Is this a bug in TVM? @ezyang

Thrsu, Jul 24 '23 17:07

One thing I notice is that the ELU alpha (-1.8574e+38) is very, very big.

ezyang, Jul 24 '23 17:07
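To put the size of alpha in perspective, here is a stdlib-only sketch that rounds the reproducer's alpha to float32 and measures the float32 spacing (ulp) at that magnitude. The helper names `to_f32` and `f32_ulp` are illustrative and not part of TVM or PyTorch. |alpha| is more than half of the largest finite float32 and lies in [2**127, 2**128), where adjacent float32 values are 2**104 apart. The 2.028241e+31 max absolute difference in the traceback matches 2**104 to the printed precision, so an absolute difference that large can correspond to a one-ulp discrepancy at that scale.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def f32_ulp(x: float) -> float:
    """Gap between |x| (a float32 value) and the next larger float32."""
    bits = struct.unpack("<I", struct.pack("<f", abs(x)))[0]
    next_up = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    return next_up - abs(x)


FLOAT32_MAX = to_f32(3.4028234663852886e38)  # largest finite float32
alpha = to_f32(-1.8574e38)                   # alpha used in the reproducer

print(abs(alpha) / FLOAT32_MAX)    # share of the float32 range, roughly 0.55
print(f32_ulp(alpha) == 2.0**104)  # True: spacing near |alpha| is 2**104
print(f"{f32_ulp(alpha):.6e}")     # 2.028241e+31
```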

@ezyang @Thrsu, could you please check whether my isolation is reasonable? I think it is a numerical bug caused by the 'FastMath' pass. Either disabling 'FastMath' or increasing the precision (running in float64 instead of float32) solves it.

import torch
from tvm import relay
import tvm
import numpy as np

# original issue:
# https://github.com/apache/tvm/issues/15396

# --------------- solution 1: disable the FastMath pass
m = torch.nn.ELU(alpha=-1.8574e+38)
input_data = torch.randn([9, 12, 19, 5, 10], dtype=torch.float32)
torch_outputs = m(input_data).numpy()
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([9, 12, 19, 5, 10]))]

mod, params = relay.frontend.from_pytorch(trace, input_shapes)
input_data_np = input_data.numpy()

# Same build as the reproducer, but with the FastMath pass disabled
with tvm.transform.PassContext(opt_level=3, disabled_pass=['FastMath']):
    exe = relay.create_executor('graph', mod=mod, params=params, device=tvm.device('llvm', 0), target='llvm').evaluate()
input_tvm = {'input0': tvm.nd.array(input_data_np.astype(np.float32))}
tvm_outputs = exe(**input_tvm).numpy()
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)

# --------------- solution 2: increase the precision (float64 instead of float32)
m = torch.nn.ELU(alpha=-1.8574e+38)
input_data = torch.randn([9, 12, 19, 5, 10], dtype=torch.float64)
torch_outputs = m(input_data).numpy()
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([9, 12, 19, 5, 10]))]

mod, params = relay.frontend.from_pytorch(trace, input_shapes)
input_data_np = input_data.numpy()

# Default passes (FastMath not disabled); only the dtype changes
with tvm.transform.PassContext(opt_level=3):
    exe = relay.create_executor('graph', mod=mod, params=params, device=tvm.device('llvm', 0), target='llvm').evaluate()
input_tvm = {'input0': tvm.nd.array(input_data_np.astype(np.float64))}
tvm_outputs = exe(**input_tvm).numpy()
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)

How does the error arise?

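TVM decomposes ELU into exp-sub-relu-mul-relu, i.e. -alpha * relu(1 - exp(x)) + relu(x). The FastMath pass converts exp to fast_exp, which is an approximation and introduces a small error. The following ops then enlarge that error: when x is close to 0, exp(x) is close to 1, so the subtraction 1 - exp(x) turns a small error in exp into a much larger relative error, and multiplying by the huge alpha turns it into a very large absolute difference. I think this is not really a logic problem in TVM's transformations, but it implies that TVM may need some kind of precision control.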

hxzd5568, Feb 08 '24 14:02
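The error amplification described in the comment above can be illustrated with plain Python floats. This sketch assumes the decomposition -alpha * relu(1 - exp(x)) + relu(x). It stands in for fast_exp with a hypothetical `approx_exp` that is simply the exact exp off by a fixed relative error of 1e-7, so it reproduces neither TVM's actual fast_exp approximation nor float32 arithmetic. Near x = 0, exp(x) is close to 1 and 1 - exp(x) cancels, so the relative error grows by roughly exp(x) / (1 - exp(x)). Multiplying by alpha leaves that relative error unchanged but scales the absolute difference by |alpha|.

```python
import math


def elu_reference(x: float, alpha: float) -> float:
    """ELU as documented by PyTorch: x if x > 0 else alpha * (exp(x) - 1)."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def relu(v: float) -> float:
    return max(v, 0.0)


def elu_decomposed(x: float, alpha: float, exp=math.exp) -> float:
    """exp -> sub -> relu -> mul, plus relu(x): -alpha * relu(1 - exp(x)) + relu(x)."""
    return -alpha * relu(1.0 - exp(x)) + relu(x)


def approx_exp(x: float, rel_err: float = 1e-7) -> float:
    """Hypothetical stand-in for fast_exp: exact exp off by a fixed relative error."""
    return math.exp(x) * (1.0 + rel_err)


def output_rel_err(x: float, alpha: float) -> float:
    """Relative error of the decomposed ELU when exp is replaced by approx_exp."""
    ref = elu_reference(x, alpha)
    got = elu_decomposed(x, alpha, exp=approx_exp)
    return abs(got - ref) / abs(ref)


alpha = -1.8574e38
for x in (-3.0, -0.1, -1e-3):
    ref = elu_reference(x, alpha)
    abs_diff = abs(elu_decomposed(x, alpha, exp=approx_exp) - ref)
    print(f"x={x:>7}: relative error {output_rel_err(x, alpha):.1e}, "
          f"absolute difference {abs_diff:.1e}")
```

Under these assumptions, x = -1e-3 ends up with a relative error of about 1e-4, above the rtol=1e-5 used in the test, while x = -3.0 stays around 5e-9.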