Core ML calculation error for some parameter setups of torch.nn.AvgPool2d
Describing the bug
After converting torch.nn.AvgPool2d, Core ML prediction (both on a host Mac using libcoremlpython and on an iPhone 13 Pro) gives erroneous output for the parameter setup "kernel size 7, padding 3, stride 4", while the results for some other parameter setups are correct.
To Reproduce
A minimal code example:
#!/usr/bin/env python
# coding: utf-8
import torch.nn
import coremltools as ct
## 1. setup torch module
module_OK_case1 = torch.nn.AvgPool2d(
    kernel_size=(3, 3),
    stride=(2, 2),
    padding=(1, 1),
    ceil_mode=False,
    count_include_pad=True
)
# case2 is a global pooling for input x (1, 1, 16, 12)
module_OK_case2 = torch.nn.AvgPool2d(
    kernel_size=(16, 12),
    stride=(1, 1),
    padding=(0, 0),
    ceil_mode=False,
    count_include_pad=True
)
module_bug_reproduce = torch.nn.AvgPool2d(
    kernel_size=(7, 7),
    stride=(4, 4),
    padding=(3, 3),
    ceil_mode=False,
    count_include_pad=True
)
## 2. choose the model here
# torch_model = module_OK_case1.eval()
# torch_model = module_OK_case2.eval()
torch_model = module_bug_reproduce.eval()
# it appears input shape does not matter
x = torch.randn(1, 1, 16, 12)
x = torch.randn(1, 1, 32, 32)
x = torch.randn(1, 128, 64, 64)
## 3. convert coreml mlmodel
traced_model = torch.jit.trace(torch_model, x)
mlmodel = ct.convert(traced_model, inputs=[ct.TensorType(name="input", shape=x.shape)])
spec = mlmodel.get_spec()
input_name = spec.description.input[0].name
output_name = spec.description.output[0].name
## 4. forward the same tensor with both torch and coreml
y = torch_model(x)
input_dict = { input_name : x }
coreml_out = mlmodel.predict(input_dict)
cy = coreml_out[ output_name ]
## 5. y is from torch, cy is from coreml, the diff sum is expected to be 0
print('Expected 0:', (y - cy).sum())
print('Yet 0:', (y - cy * 7*7).sum()) # 7*7 is for module_bug_reproduce
The model conversion succeeds, but there is a numerical mismatch in predictions. My output:
Expected 0: tensor(34.4464)
Yet 0: tensor(-0.0048)
It seems that the erroneous output incorrectly divides by the AvgPool kernel size (7x7) twice.
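To make the "divided twice" hypothesis concrete, here is a minimal pure-Python sketch of average pooling with zero padding (the count_include_pad=True behaviour). It is an illustration only, not Core ML's or PyTorch's actual implementation; the function name avg_pool2d and the extra_divide flag are hypothetical and exist only to mimic the suspected bug.

```python
def avg_pool2d(x, k, s, p, extra_divide=False):
    """Average-pool a 2D list of floats with zero padding.

    Padded positions count toward the divisor (count_include_pad=True),
    and the output size uses floor division (ceil_mode=False).
    extra_divide is a hypothetical switch that divides by k*k a second time.
    """
    h, w = len(x), len(x[0])
    out_h = (h + 2 * p - k) // s + 1
    out_w = (w + 2 * p - k) // s + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            total = 0.0
            for di in range(k):
                for dj in range(k):
                    r, c = i * s - p + di, j * s - p + dj
                    # Positions outside the input are zero padding.
                    if 0 <= r < h and 0 <= c < w:
                        total += x[r][c]
            value = total / (k * k)
            if extra_divide:
                value /= k * k  # the suspected second division
            row.append(value)
        out.append(row)
    return out


# A 16x12 input, matching one of the shapes used in the script above.
x = [[float(r * 12 + c) for c in range(12)] for r in range(16)]
correct = avg_pool2d(x, k=7, s=4, p=3)
buggy = avg_pool2d(x, k=7, s=4, p=3, extra_divide=True)

# Multiplying the doubly divided result by 7*7 recovers the correct one.
max_gap = max(
    abs(a - b * 49)
    for row_a, row_b in zip(correct, buggy)
    for a, b in zip(row_a, row_b)
)
print(len(correct), len(correct[0]), max_gap)
```

If Core ML really applies the divisor twice, its output times 49 should agree with PyTorch up to floating-point noise, which is what the "Yet 0" line checks.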
System environment:
- coremltools version: 6.1
- OS: MacOS Ventura 13.4.1 (22F2083)
- torch: 2.0.1
Additional context
This problem originates from a network that I was working on, where wrong calculation results were observed on an iPhone 13 Pro.
Later, I confirmed that on the host Mac, coremltools model.predict yields the same result as the iPhone, which is also wrong.
The first major numerical mismatch in the network comes from an AvgPool2D layer whose parameters are kernel 7x7, stride 4x4, padding 3x3, while the results from all previous AvgPool2D layers with other parameters are correct.
It seems to be a bug in the CoreML framework rather than in coremltools. But I think:
- It's still worth reporting here, since others may come across it.
- It's possible the bug is related to coremltools, because I'm not sure whether what's converted is correct.
- (Update) I have already reported it to Apple CoreML. Their bug report website lacks a rich text editor, so that bug report links here for better readability.
I can reproduce this issue. It only affects the neural network backend. If you use the new mlprogram backend, the outputs match. To use the new backend, add convert_to="mlprogram" to your ct.convert calls.
I cannot confirm on my end.
#!/usr/bin/env python
# coding: utf-8
import torch.nn
import coremltools as ct
## 1. setup torch module
module_bug_reproduce = torch.nn.AvgPool2d(
    kernel_size=(7, 7),
    stride=(4, 4),
    padding=(3, 3),
    ceil_mode=False,
    count_include_pad=True
)
torch_model = module_bug_reproduce.eval()
# it appears input shape does not matter
x = torch.randn(1, 1, 16, 12)
x = torch.randn(1, 1, 32, 32)
x = torch.randn(1, 128, 64, 64)
## 3. convert coreml mlmodel
traced_model = torch.jit.trace(torch_model, x)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=x.shape)],
    convert_to="mlprogram",
)
spec = mlmodel.get_spec()
input_name = spec.description.input[0].name
output_name = spec.description.output[0].name
## 4. forward the same tensor with both torch and coreml
y = torch_model(x)
input_dict = { input_name : x }
coreml_out = mlmodel.predict(input_dict)
cy = coreml_out[ output_name ]
## 5. y is from torch, cy is from coreml, the diff sum is expected to be 0
print('Expected 0:', (y - cy).sum())
print('Yet 0:', (y - cy * 7*7).sum()) # 7*7 is for module_bug_reproduce
# import pdb; pdb.set_trace()
print('Done')
Full Output:
$ python minimum_reproduce.py
scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.0.1 has not been tested with coremltools. You may run into unexpected errors. Torch 1.12.1 is the most recent version that has been tested.
Converting PyTorch Frontend ==> MIL Ops: 92%|█████████ | 12/13 [00:00<00:00, 6980.81 ops/s]
Running MIL Common passes: 0%| | 0/39 [00:00<?, ? passes/s]/Users/oliverxu/miniconda3/envs/venus/lib/python3.9/site-packages/coremltools/converters/mil/mil/passes/name_sanitization_utils.py:135: UserWarning: Output, '14', of the source model, has been renamed to 'var_14' in the Core ML model.
warnings.warn(msg.format(var.name, new_name))
Running MIL Common passes: 100%|██████████| 39/39 [00:00<00:00, 8399.81 passes/s]
Running MIL FP16ComputePrecision pass: 100%|██████████| 1/1 [00:00<00:00, 2118.34 passes/s]
Running MIL Clean up passes: 100%|██████████| 11/11 [00:00<00:00, 11986.84 passes/s]
Expected 0: tensor(31.7326)
Yet 0: tensor(0.0081)
Done
Exception ignored in: <function MLModel.__del__ at 0x146faaee0>
Traceback (most recent call last):
File "/Users/oliverxu/miniconda3/envs/venus/lib/python3.9/site-packages/coremltools/models/model.py", line 369, in __del__
ImportError: sys.meta_path is None, Python is likely shutting down
And I also notice this warning:
Torch version 2.0.1 has not been tested with coremltools. You may run into unexpected errors. Torch 1.12.1 is the most recent version that has been tested.
I'll try using Torch 1.12.1 to see whether things are different.
Also
My production uses .mlmodelc, and mlprogram requires a stricter iOS target. So even if mlprogram works, it's not my preferred option.
Is there any chance that I could try to fix it? If so, any hints on where to get started?
Don't worry about those two warnings.
Using your most recent code, things look good to me. (y - cy).sum() being 31.7326 may seem like a lot, but given that y and cy each contain 32,768 elements, that is a mean difference of roughly 0.001.
y and cy are never going to match exactly. There are many reasons for this. One of the biggest is that y and cy are likely calculated on different hardware devices. PyTorch is most likely using only the CPU, while Core ML can use the CPU, GPU and ANE.
If you care a lot about having y and cy match as closely as possible, you can restrict your Core ML model to run only on the CPU. To do this, add compute_units=ct.ComputeUnit.CPU_ONLY to ct.convert.
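The element count and the per-element average discussed in this comment can be checked with a few lines of arithmetic. This is only a sketch: pooled_size is a hypothetical helper, and the 31.7326 figure is copied from the output posted earlier in the thread.

```python
def pooled_size(size, kernel=7, stride=4, padding=3):
    """Output length along one axis for ceil_mode=False pooling."""
    return (size + 2 * padding - kernel) // stride + 1


channels = 128
side = pooled_size(64)                  # spatial size for the 64x64 input
num_elements = channels * side * side   # elements in y (batch size 1)

signed_sum = 31.7326                    # "Expected 0" value from the posted output
mean_signed_diff = signed_sum / num_elements
print(num_elements, mean_signed_diff)
```

Because positive and negative differences cancel in a signed sum, this average can only underestimate the typical per-element error magnitude.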
My production uses .mlmodelc, and mlprogram requires a stricter iOS target. So even if mlprogram works, it's not my preferred option.
I don't understand. Why do you need to use a .mlmodelc rather than a .mlpackage? Do you need to deploy to older versions of iOS or macOS?
Using your most recent code, things look good to me. (y - cy).sum() being 31.7326 may seem like a lot, but given that y and cy each contain 32,768 elements, that is a mean difference of roughly 0.001.
That's a good point, thanks. I'll use more criteria next time.
Why do you need to use a .mlmodelc rather than a .mlpackage?
I'm using Objective-C, and my product has to comply with a quite old iOS target, so yes, I need to deploy to older versions of iOS.
I'm not sure about using .mlpackage from Obj-C. I'll need some time to try it, so further replies may be delayed (as for #1904 and #1905).
Even with 'mlprogram' and coremltools 7.0b1, I was still unable to reproduce the matching outputs.
Snippet I was using:
#!/usr/bin/env python
# coding: utf-8
import numpy as np
import torch.nn
import coremltools as ct
class Criterion:
    def __init__(self):
        pass

    @staticmethod
    def absdiffmax(ref, test):
        # largest absolute element-wise difference
        absmax_error = np.abs(ref - test).max()
        return absmax_error

    @staticmethod
    def absdiff(ref, test):
        # l1_error = np.sum(np.abs(ref - test)) / ref.size
        l1_error = np.abs(ref - test).mean()
        return l1_error

    @staticmethod
    def mse(ref, test):
        # l2_error = np.sum(np.power((ref - test), 2)) / ref.size
        mse_error = np.power((ref - test), 2).mean()
        return mse_error

    @staticmethod
    def snr(ref, test):
        # snr_error = np.sum(np.power((ref - test), 2) / (np.power(ref, 2) + 1e-6)) / ref.size
        snr_error = (np.power((ref - test), 2) / (np.power(ref, 2) + 1e-6)).mean()
        return snr_error

    @staticmethod
    def report_diff(ref_blob, blob):
        absmax_error = Criterion.absdiffmax(ref_blob, blob)
        l1_error = Criterion.absdiff(ref_blob, blob)
        mse_error = Criterion.mse(ref_blob, blob)
        snr_error = Criterion.snr(ref_blob, blob)
        print(
            'absmax', absmax_error.numpy(),
            'L1 error', l1_error.numpy(),
            'MSE error', mse_error.numpy(),
            'SNR error', snr_error.numpy()
        )
## 1. setup torch module
module_OK_case1 = torch.nn.AvgPool2d(
    kernel_size=(3, 3),
    stride=(2, 2),
    padding=(1, 1),
    ceil_mode=False,
    count_include_pad=True
)
# case2 is a global pooling for input x (1, 1, 16, 12)
module_OK_case2 = torch.nn.AvgPool2d(
    kernel_size=(16, 12),
    stride=(1, 1),
    padding=(0, 0),
    ceil_mode=False,
    count_include_pad=True
)
module_bug_reproduce = torch.nn.AvgPool2d(
    kernel_size=(7, 7),
    stride=(4, 4),
    padding=(3, 3),
    ceil_mode=False,
    count_include_pad=True
)
## 2. choose the model here
# torch_model = module_OK_case1.eval()
# torch_model = module_OK_case2.eval()
torch_model = module_bug_reproduce.eval()
# it appears input shape does not matter
x = torch.randn(1, 1, 16, 12)
x = torch.randn(1, 1, 32, 32)
x = torch.randn(1, 128, 64, 64)
## 3. convert coreml mlmodel
traced_model = torch.jit.trace(torch_model, x)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=x.shape)],
    convert_to="mlprogram",
)
spec = mlmodel.get_spec()
input_name = spec.description.input[0].name
output_name = spec.description.output[0].name
## 4. forward the same tensor with both torch and coreml
y = torch_model(x)
input_dict = { input_name : x }
coreml_out = mlmodel.predict(input_dict)
cy = coreml_out[ output_name ]
## 5. y is from torch, cy is from coreml, the diff sum is expected to be 0
Criterion.report_diff(y, cy)
Criterion.report_diff(y, cy * 7*7) # 7*7 is for module_bug_reproduce
# import pdb; pdb.set_trace()
print('Tested with coremltools version:', ct.__version__)
Output:
$ python minimum_reproduce.py
scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.0.1 has not been tested with coremltools. You may run into unexpected errors. Torch 2.0.0 is the most recent version that has been tested.
Converting PyTorch Frontend ==> MIL Ops: 92%|█████████ | 12/13 [00:00<00:00, 6607.80 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 15098.29 passes/s]
Running MIL default pipeline: 0%| | 0/64 [00:00<?, ? passes/s]/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/mil/passes/defs/preprocess.py:262: UserWarning: Output, '14', of the source model, has been renamed to 'var_14' in the Core ML model.
warnings.warn(msg.format(var.name, new_name))
Running MIL default pipeline: 100%|██████████| 64/64 [00:00<00:00, 9698.51 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 11/11 [00:00<00:00, 31622.58 passes/s]
absmax 0.68464917 L1 error 0.10804626 MSE error 0.018488485 SNR error 0.9511232
absmax 0.00037407875 L1 error 3.74514e-05 MSE error 2.580043e-09 SNR error 7.118919e-06
Tested with coremltools version: 7.0b1
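This output compares the two tensors (one from torch, one from coreml) and prints four distances between them:
- the maximum absolute difference, i.e. the largest element-wise L1 distance in the tensor
- the average L1 distance
- the average L2 distance (MSE)
- an SNR-style error, similar to the MSE but with each squared difference normalized by the squared reference value
It shows that the coreml result * 49 is almost exactly equal to the torch result.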
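Absolute-value metrics like the ones listed above are more telling than a plain signed sum, because signed errors can cancel. The toy example below is a sketch with made-up numbers and hypothetical helper names; it is not taken from the scripts in this thread.

```python
def signed_sum(ref, test):
    """Sum of element-wise differences; errors of opposite sign cancel."""
    return sum(r - t for r, t in zip(ref, test))


def max_abs_diff(ref, test):
    """Largest absolute element-wise difference."""
    return max(abs(r - t) for r, t in zip(ref, test))


def mean_abs_diff(ref, test):
    """Average absolute element-wise difference (average L1 distance)."""
    return sum(abs(r - t) for r, t in zip(ref, test)) / len(ref)


# Made-up reference values, symmetric around zero.
ref = [1.0, -1.0, 2.0, -2.0]
# Mimic an output that was divided by 49 once too often.
test = [r / 49 for r in ref]

print('signed sum   :', signed_sum(ref, test))
print('max abs diff :', max_abs_diff(ref, test))
print('mean abs diff:', mean_abs_diff(ref, test))
```

With this symmetric input the signed sum stays at zero even though every element is wrong, while the absolute metrics expose the mismatch.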
With the original code, it does work with convert_to="mlprogram".
However, you're right that this new combination does not work with convert_to="mlprogram". It is curious that multiplying by 49 makes things match.
import numpy as np
import torch.nn
import coremltools as ct
module_bug_reproduce = torch.nn.AvgPool2d(
    kernel_size=(7, 7),
    stride=(4, 4),
    padding=(3, 3),
)
torch_model = module_bug_reproduce.eval()
x = torch.randn(1, 128, 64, 64)
traced_model = torch.jit.trace(torch_model, x)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=x.shape)],
    outputs=[ct.TensorType(name="y")],
    convert_to="mlprogram",
)
y_t = torch_model(x)
y_cm = mlmodel.predict({'input': x})['y']
print(np.abs(y_t - y_cm).mean())
print(np.abs(y_t - y_cm * 49).mean())
outputs:
tensor(0.1089)
tensor(3.7489e-05)
This is a problem with one of our optimization passes. Passing pass_pipeline=ct.PassPipeline.EMPTY to ct.convert causes the predictions to match.
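For readers who want to try this workaround, a minimal sketch follows. It assumes torch and a coremltools version that accepts the pass_pipeline argument (7.0b1 is the version used in this thread); the helper name convert_with_empty_pipeline is hypothetical, and the imports are placed inside the function only so the sketch can be loaded without those packages installed.

```python
def convert_with_empty_pipeline(torch_module, example_input):
    """Trace a PyTorch module and convert it using an empty pass pipeline.

    Hypothetical helper for illustration; requires torch and coremltools.
    """
    # Lazy imports: both packages are heavy, optional dependencies of this sketch.
    import coremltools as ct
    import torch

    traced = torch.jit.trace(torch_module.eval(), example_input)
    return ct.convert(
        traced,
        inputs=[ct.TensorType(name="input", shape=example_input.shape)],
        convert_to="mlprogram",
        # Workaround suggested in this thread: use an empty pass pipeline.
        pass_pipeline=ct.PassPipeline.EMPTY,
    )
```

Calling it as convert_with_empty_pipeline(module_bug_reproduce, x) and then running predict on the result should, per the comment above, give predictions that match PyTorch. Skipping optimization passes may cost some runtime performance, so this is best treated as a temporary measure.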
Passing pass_pipeline=ct.PassPipeline.EMPTY to ct.convert causes the predictions to match.
Yes, this is a valid workaround, thanks.