
Why doesn't ORT_lightning_X_trt_complete compare ONNX Runtime and PyTorch results?

Open JuiceLemonLemon opened this issue 3 years ago • 4 comments

Hi, I want to know why these lines are commented out. I found that the script fails when I uncomment them.

# # compare ONNX Runtime and PyTorch results
# np.testing.assert_allclose(to_numpy(torch_outs), ort_outs[0], rtol=1e-03, atol=1e-05)
#
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")

JuiceLemonLemon avatar Nov 29 '21 05:11 JuiceLemonLemon

@JuiceLemonLemon Hi, we found that the default tolerance (relative 1e-3, absolute 1e-5) may not be suitable for our network outputs. Whether the check passes also depends on the inputs. In addition, the tracking metrics of the PyTorch model and the ONNX model are almost the same, so please feel free to use the provided transformation script.
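
For readers unfamiliar with how the tolerance is applied: `np.testing.assert_allclose(actual, desired, rtol, atol)` passes only if every element satisfies `|actual - desired| <= atol + rtol * |desired|`. The sketch below restates that rule with the values from the commented-out check (rtol=1e-3, atol=1e-5). The helper name `within_tolerance` and the sample arrays are hypothetical and are not part of the script.

```python
import numpy as np

def within_tolerance(actual, desired, rtol=1e-3, atol=1e-5):
    """Element-wise form of the rule np.testing.assert_allclose enforces."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    return np.abs(actual - desired) <= atol + rtol * np.abs(desired)

# Hypothetical outputs for illustration only, not taken from the model.
desired = np.array([0.5, 0.25])
close = np.array([0.5004, 0.2502])  # off by less than 0.1% of each value
far = np.array([0.52, 0.26])        # off by about 4% of each value

print(within_tolerance(close, desired))  # mask of elements inside the tolerance
print(within_tolerance(far, desired))
```

With outputs in the range 0 to 1, rtol=1e-3 allows a deviation of at most about 0.001 per element, so a check like this can fail even when downstream metrics barely change.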

MasterBin-IIAU avatar Nov 30 '21 01:11 MasterBin-IIAU

Okay, thank you for your reply. I will use this script.

JuiceLemonLemon avatar Nov 30 '21 06:11 JuiceLemonLemon

In: ORT_lightning_X_trt_complete.py

When uncommenting these lines:

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_outs), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

I get an assertion failure:

AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 4 / 4 (100%)
Max absolute difference: 0.0618192
Max relative difference: 0.09900394
 x: array([[0.528552, 0.320006, 0.686231, 0.564979]], dtype=float32)
 y: array([[0.507053, 0.29745 , 0.624411, 0.525621]], dtype=float32)

orilifs avatar Jan 18 '22 09:01 orilifs

I have the same problem:

AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 4 / 4 (100%)
Max absolute difference: 0.11626142
Max relative difference: 0.2094227
 x: array([[0.477019, 0.421438, 0.680261, 0.671413]], dtype=float32)
 y: array([[0.436204, 0.351285, 0.621232, 0.555152]], dtype=float32)

Aspirinkb avatar Jan 18 '22 09:01 Aspirinkb
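
The two failure reports above can be checked by hand. The sketch below copies the `x` (PyTorch) and `y` (ONNX Runtime) arrays from the assertion messages and recomputes the maximum absolute and relative differences, plus the smallest rtol that would let the check pass with atol=1e-5. The `summarize` helper and the `reports` dictionary are hypothetical names introduced for illustration. The arrays are the rounded values printed in the messages, so results may differ from the reported figures in the last digits.

```python
import numpy as np

# x = to_numpy(torch_outs), y = ort_outs[0], copied from the messages above.
reports = {
    "orilifs": (
        np.array([[0.528552, 0.320006, 0.686231, 0.564979]], dtype=np.float32),
        np.array([[0.507053, 0.29745, 0.624411, 0.525621]], dtype=np.float32),
    ),
    "Aspirinkb": (
        np.array([[0.477019, 0.421438, 0.680261, 0.671413]], dtype=np.float32),
        np.array([[0.436204, 0.351285, 0.621232, 0.555152]], dtype=np.float32),
    ),
}

def summarize(x, y, atol=1e-5):
    """Return (max abs diff, max rel diff, smallest rtol that would pass)."""
    abs_diff = np.abs(x - y)
    rel_diff = abs_diff / np.abs(y)
    # Solve |x - y| <= atol + rtol * |y| for rtol, element by element.
    needed_rtol = max(0.0, float(np.max((abs_diff - atol) / np.abs(y))))
    return float(abs_diff.max()), float(rel_diff.max()), needed_rtol

for user, (x, y) in reports.items():
    max_abs, max_rel, needed = summarize(x, y)
    print(f"{user}: max abs {max_abs:.6f}, max rel {max_rel:.6f}, rtol needed {needed:.4f}")
```

In both reports the relative difference (about 10% and 21%) is far above rtol=1e-3, so relaxing the tolerance slightly would not make these particular runs pass. Whether a gap of this size matters for tracking accuracy is a separate question that the maintainers answer with their metric comparison.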