[Bug] ELU produced inconsistent inference results with PyTorch
Actual behavior
Traceback (most recent call last):
...
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)
File "/workplace/software/miniconda3/envs/torch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1528, in assert_allclose
verbose=verbose, header=header, equal_nan=equal_nan)
File "/workplace/software/miniconda3/envs/torch/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 840, in assert_array_compare
raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-05
Mismatched elements: 1 / 102600 (0.000975%)
Max absolute difference: 2.028241e+31
Max relative difference: 1.2892904e-05
x: array([[[[[2.116557e-01, 6.871684e-01, 1.114575e+00, ...,
5.243167e-01, 1.032357e+00, 9.073991e-01],
[3.627845e+37, 1.417131e+00, 6.541348e+37, ...,...
y: array([[[[[2.116557e-01, 6.871684e-01, 1.114575e+00, ...,
5.243167e-01, 1.032357e+00, 9.073991e-01],
[3.627845e+37, 1.417131e+00, 6.541348e+37, ...,...
Environment
- TVM version: 0.13.dev0
- PyTorch: 1.4.0
Steps to reproduce
import torch
from tvm import relay
import tvm
import numpy as np
m = torch.nn.ELU(alpha=-1.8574e+38,)
input_data = torch.randn([9, 12, 19, 5, 10], dtype=torch.float32)
torch_outputs = m(input_data)
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([9, 12, 19, 5, 10]))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
input_data_np = input_data.numpy()
with tvm.transform.PassContext(opt_level=3):
    exe = relay.create_executor('graph', mod=mod, params=params, device=tvm.device('llvm', 0), target='llvm').evaluate()
input_tvm = {'input0': tvm.nd.array(input_data_np.astype(np.float32))}
tvm_outputs = exe(**input_tvm).asnumpy()
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)
Triage
- needs-triage
- frontend:torch
Is this a bug in TVM? @ezyang
One thing I notice is that the ELU alpha is extremely large in magnitude: -1.8574e+38, on the same order as the float32 maximum (~3.4e+38).
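For scale, a quick NumPy check (illustrative only, not part of the repro) of how that alpha compares with the float32 range:

import numpy as np

alpha = -1.8574e+38                           # the alpha used in the repro
print(np.finfo(np.float32).max)               # 3.4028235e+38, the float32 maximum
print(abs(alpha) / np.finfo(np.float32).max)  # ~0.55: alpha is about half of the float32 max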
@ezyang @Thrsu, could you please check whether my isolation is reasonable? I think this is a numerical precision issue caused by the 'FastMath' pass; disabling 'FastMath' or increasing the precision resolves it.
import torch
from tvm import relay
import tvm
import numpy as np
# original issue
# https://github.com/apache/tvm/issues/15396
# --------------- solution 1: disable the 'FastMath' pass
m = torch.nn.ELU(alpha=-1.8574e+38,)
input_data = torch.randn([9, 12, 19, 5, 10], dtype=torch.float32)
torch_outputs = m(input_data)
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([9, 12, 19, 5, 10]))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
input_data_np = input_data.numpy()
with tvm.transform.PassContext(opt_level=3, disabled_pass=['FastMath']):
    exe = relay.create_executor('graph', mod=mod, params=params, device=tvm.device('llvm', 0), target='llvm').evaluate()
input_tvm = {'input0': tvm.nd.array(input_data_np.astype(np.float32))}
tvm_outputs = exe(**input_tvm).asnumpy()
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)
# --------------- solution 2: increase the precision (use float64)
m = torch.nn.ELU(alpha=-1.8574e+38,)
input_data = torch.randn([9, 12, 19, 5, 10], dtype=torch.float64)
torch_outputs = m(input_data)
trace = torch.jit.trace(m, input_data)
input_shapes = [('input0', torch.Size([9, 12, 19, 5, 10]))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes)
input_data_np = input_data.numpy()
with tvm.transform.PassContext(opt_level=3):
    exe = relay.create_executor('graph', mod=mod, params=params, device=tvm.device('llvm', 0), target='llvm').evaluate()
input_tvm = {'input0': tvm.nd.array(input_data_np.astype(np.float64))}
tvm_outputs = exe(**input_tvm).asnumpy()
np.testing.assert_allclose(torch_outputs, tvm_outputs, rtol=1e-5, atol=1e-5)
How does the error arise?
TVM decomposes ELU into an exp-sub-relu-mul-relu sequence. The FastMath pass rewrites exp as fast_exp, which introduces a small relative error, and the subsequent ops amplify it: multiplying by an alpha of magnitude ~1.86e+38 turns a ~1e-5 relative error into a huge absolute difference, while the relative difference stays around 1e-5, as in the mismatch reported above. I don't think this is a logic problem in TVM's transformations, but it suggests that TVM may need a precision-control mechanism.
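To illustrate the amplification, here is a plain NumPy sketch (not TVM code), assuming the approximated exp is off by roughly the ~1e-5 relative error seen in the mismatch above:

import numpy as np

alpha = -1.8574e+38                       # alpha from the repro
x = np.float32(-0.5)                      # negative input, so ELU takes the exp branch

exact = alpha * (np.exp(x) - 1.0)         # reference: alpha * (exp(x) - 1)
fast = np.exp(x) * (1.0 + 1e-5)           # assume exp is off by ~1e-5 relative error
approx = alpha * (fast - 1.0)

print(abs(exact - approx))                # absolute error on the order of 1e+33
print(abs(exact - approx) / abs(exact))   # relative error stays around 1.5e-5

The absolute error explodes only because alpha is enormous; the relative error stays near the fast-math level, which matches the reported max relative difference of ~1.3e-5.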