Large output errors when converting MobileNetV3 in FP16
🐞Describing the bug
I trained the default torchvision implementation of MobileNetV3 Large and converted it to CoreML. In fp32 both models give identical results, but when converted to fp16 the CoreML output deviates far more from the fp32 reference than the PyTorch fp16 model does. I created a self-contained example that shows the difference on uniform grayscale images (pixel values 0 to 255), but the errors are just as large on my real data.
To Reproduce
import asyncio
import coremltools as ct
import matplotlib.pyplot as plt
import numpy as np
import torch
from coremltools.converters.mil.mil import types
from coremltools.models.ml_program.experimental.debugging_utils import MLModelComparator
from PIL import Image
from tabulate import tabulate
from torchvision.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
from tqdm import tqdm
with torch.no_grad():
    model_torch32 = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    model_torch32.eval()
    model_torch16 = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    model_torch16.eval().to(torch.float16)
    model_coreml32 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
        inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
        outputs=[ct.TensorType(name="y", dtype=types.fp32)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT32,
    )
    model_coreml16 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.zeros((1, 3, 224, 224))),
        inputs=[ct.ImageType(name="image", shape=(1, 3, 224, 224), scale=1 / 255.0)],
        outputs=[ct.TensorType(name="y", dtype=types.fp16)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT16,
    )
# Compare results
xs = list(range(0, 256))
torch32_preds = []
torch16_preds = []
coreml32_preds = []
coreml16_preds = []
for x in tqdm(xs):
    x_torch = torch.full((1, 3, 224, 224), x, dtype=torch.float32) / 255.0
    torch32_preds.append(model_torch32(x_torch)[0, 0].item())
    torch16_preds.append(model_torch16(x_torch.to(torch.float16))[0, 0].item())
    x_coreml = Image.fromarray(np.full((224, 224, 3), x, dtype=np.uint8))
    coreml32_preds.append(model_coreml32.predict({"image": x_coreml})["y"][0, 0])
    coreml16_preds.append(model_coreml16.predict({"image": x_coreml})["y"][0, 0])
# Print error table
print(
    tabulate(
        [
            [
                "Torch ",
                np.mean(np.abs(np.array(torch32_preds) - np.array(torch16_preds))),
                np.mean((np.array(torch32_preds) - np.array(torch16_preds)) ** 2),
            ],
            [
                "CoreML",
                np.mean(np.abs(np.array(coreml32_preds) - np.array(coreml16_preds))),
                np.mean((np.array(coreml32_preds) - np.array(coreml16_preds)) ** 2),
            ],
        ],
        headers=["Model", "L1", "L2"],
    )
)
# plot
plt.plot(xs, torch32_preds, label="Torch32")
plt.plot(xs, torch16_preds, label="Torch16")
plt.plot(xs, coreml32_preds, label="CoreML32", linestyle="dashed")
plt.plot(xs, coreml16_preds, label="CoreML16", linestyle="dashed")
plt.legend()
plt.title("MobileNetV3: Torch vs CoreML")
plt.xlabel("Input pixel value")
plt.ylabel("Output value")
plt.savefig("torch_vs_coreml.png")
plt.close()
fp32 vs fp16 comparison for Torch and CoreML:
Model L1 Error L2 Error
------- --------- -----------
Torch 0.0164849 0.000468726
CoreML 0.0393638 0.00248957 <-- much larger
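For context, part of any fp32 vs fp16 gap is unavoidable: just rounding the fp32 outputs to fp16 already sets an error floor. A minimal sketch of that floor (reusing torch32_preds and np from the script above):

# Error floor from storing the fp32 reference outputs as fp16.
# Anything well above this points at fp16 arithmetic inside the model,
# not just an fp16 output dtype.
ref = np.array(torch32_preds, dtype=np.float32)
rounded = ref.astype(np.float16).astype(np.float32)
print("L1 from output rounding alone:", np.mean(np.abs(ref - rounded)))
print("L2 from output rounding alone:", np.mean((ref - rounded) ** 2))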
Comparison plot (the CoreML fp16 output is extremely noisy):
I also tried MLModelComparator to locate where the error comes from, but it flags ordinary "linear" and "conv" ops as failing (depending on the atol):
def compare_outputs(operation, reference_output, target_output):
    return np.allclose(reference_output, target_output, atol=1e-1)

comparator = MLModelComparator(
    reference_model=model_coreml32,
    target_model=model_coreml16,
    num_predict_intermediate_outputs=720,
)
failing_ops = asyncio.run(
    comparator.find_failing_ops(inputs={"image": x_coreml}, compare_outputs=compare_outputs)
)
print(failing_ops)
Analyzed operation: classifier_3_weight, type: const: 1%|▏ | 4/720 [00:00<02:18, 5.18it/s]
[type: "linear"
 inputs {
   key: "x"
   value {
     arguments {
       name: "input_331"
     }
   }
 }
 inputs {
   key: "weight"
   value {
     arguments {
       name: "classifier_3_weight"
     }
   }
 }
 inputs {
   key: "bias"
   value {
     arguments {
       name: "classifier_3_bias"
     }
   }
 }
 outputs {
   name: "y"
   type {
     tensorType {
       dataType: FLOAT32
       rank: 2
       dimensions {
         constant {
           size: 1
         }
       }
       dimensions {
         constant {
           size: 1000
         }
       }
     }
   }
 }
 attributes {
   key: "name"
   value {
     type {
       tensorType {
         dataType: STRING
       }
     }
     immediateValue {
       tensor {
         strings {
           values: "linear_1"
         }
       }
     }
   }
 }
]
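As a side note, a fixed atol penalizes ops with large activations; a relative comparison may separate genuine failures from ordinary fp16 rounding. A sketch of such a compare_outputs (same signature as above; the 1% threshold is arbitrary):

def compare_outputs(operation, reference_output, target_output):
    # Scale the tolerance with the magnitude of the reference activation
    # instead of using one absolute threshold for every op.
    ref = np.asarray(reference_output, dtype=np.float32)
    tgt = np.asarray(target_output, dtype=np.float32)
    rel_err = np.max(np.abs(ref - tgt) / (np.abs(ref) + 1e-6))
    return rel_err < 1e-2

This can be passed to find_failing_ops exactly like the fixed-atol version above.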
System environment
- coremltools version: 9.0
- OS: macOS 26.1
- PyTorch version: 2.7.0
Why does CoreML fp16 have so much lower precision than PyTorch float16? Here is a simpler reproduction using a single nn.Linear layer:
import coremltools as ct
import matplotlib.pyplot as plt
import numpy as np
import torch
from coremltools.converters.mil.mil import types
from tqdm import tqdm
INPUT_DIM = 2**12
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(INPUT_DIM, 1)
        self.net.weight.data.fill_(1.0 / self.net.weight.numel())
        self.net.bias.data.fill_(0.0)

    def forward(self, x):
        return self.net(x)
with torch.no_grad():
    model_torch32 = DummyModel()
    model_torch32.eval()
    model_torch16 = DummyModel()
    model_torch16.eval().to(torch.float16)
    model_coreml32 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.randn((1, INPUT_DIM))),
        inputs=[ct.TensorType(name="x", shape=(1, INPUT_DIM))],
        outputs=[ct.TensorType(name="y", dtype=types.fp32)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT32,
    )
    model_coreml16 = ct.converters.convert(
        model=torch.jit.trace(model_torch32, torch.randn((1, INPUT_DIM))),
        inputs=[ct.TensorType(name="x", shape=(1, INPUT_DIM))],
        outputs=[ct.TensorType(name="y", dtype=types.fp16)],
        minimum_deployment_target=ct.target.iOS18,
        compute_units=ct.ComputeUnit.CPU_ONLY,
        compute_precision=ct.precision.FLOAT16,
    )
# Compare results
xs = np.linspace(-5, 5, 101)
torch32_preds = []
torch16_preds = []
coreml32_preds = []
coreml16_preds = []
for x in tqdm(xs):
    x_torch = torch.full((1, INPUT_DIM), x, dtype=torch.float32)
    torch32_preds.append(model_torch32(x_torch)[0].item())
    torch16_preds.append(model_torch16(x_torch.to(torch.float16))[0].item())
    x_coreml = np.full((1, INPUT_DIM), x, dtype=np.float32)
    coreml32_preds.append(model_coreml32.predict({"x": x_coreml})["y"][0].item())
    coreml16_preds.append(
        model_coreml16.predict({"x": x_coreml.astype(np.float16)})["y"][0].item()
    )
# plot
plt.plot(xs, torch32_preds, label="Torch32")
plt.plot(xs, torch16_preds, label="Torch16")
plt.plot(xs, coreml32_preds, label="CoreML32", linestyle="--")
plt.plot(xs, coreml16_preds, label="CoreML16", linestyle=":")
plt.legend()
plt.title("Torch vs CoreML")
plt.xlabel("Input value")
plt.ylabel("Output value")
plt.savefig("torch_vs_coreml.png")
plt.close()
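One plausible contributor to the gap (this is a guess about the mechanism, not a statement about what Core ML's CPU path actually does) is the precision of the dot-product accumulator: as far as I understand, PyTorch's CPU kernels accumulate half-precision matmuls at higher precision, whereas an accumulator kept purely in fp16 loses the small per-element contributions once the running sum gets large. A standalone numpy sketch of the same 4096-element mean that DummyModel computes:

import numpy as np

INPUT_DIM = 2**12
x_val = np.float16(5.0)
w = np.float16(1.0 / INPUT_DIM)  # same constant weight as DummyModel; exact in fp16
term = np.float16(x_val * w)     # each per-element product (also exact here)

acc32 = np.float32(0.0)          # higher-precision accumulator
acc16 = np.float16(0.0)          # pure fp16 accumulator
for _ in range(INPUT_DIM):
    acc32 += np.float32(term)
    acc16 = np.float16(acc16 + term)

print("expected:        ", 5.0)
print("fp32 accumulator:", float(acc32))  # ~5.0
print("fp16 accumulator:", float(acc16))  # stalls well below 5 once term < half an ulp of the running sum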
This is something I have observed as well -- we are currently unable to use float16 weights because of this issue.
To make matters worse, the issue fluctuates between macOS versions and chips (e.g. my MacBook Air M1 shows high numerical error for specific ops, while a MacBook M1 Max works fine).
@johan-sightic @dj-nuo @jgibson2 In the project I'm currently working on, there is also a significant error in the Linear operator when predicting on the NPU. With FP32 the results were correct, and Xcode confirmed the compute devices were CPU & GPU. I then ran the FP16 model on CPU & GPU and the results were also correct. So I believe this issue is device-related: something differs when computing on the NPU.
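For isolating this kind of device dependence, one option is to reload the same .mlpackage with different compute_units and diff the outputs directly in Python. A sketch, reusing the model_coreml16 / "x" / "y" names from the nn.Linear example above (adapt to your own model):

import coremltools as ct
import numpy as np

model_coreml16.save("model_fp16.mlpackage")
x = {"x": np.full((1, 2**12), 1.0, dtype=np.float32)}  # placeholder input

for cu in (ct.ComputeUnit.CPU_ONLY, ct.ComputeUnit.CPU_AND_GPU,
           ct.ComputeUnit.CPU_AND_NE, ct.ComputeUnit.ALL):
    m = ct.models.MLModel("model_fp16.mlpackage", compute_units=cu)
    print(cu, m.predict(x)["y"].ravel()[:5])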
@hongkun-Shao, as an example, I'm experiencing high numerical error:
- MacBook Air M1, macOS 13.6, Model A, FP16:
  - CPU_ONLY: works correctly
  - CPU_AND_GPU: bad predictions
- MacBook Pro M1 Max, macOS 15.7.1, Model B, FP16:
  - CPU_ONLY: bad predictions (WHAT???😫)
  - CPU_AND_GPU: works correctly
CoreML FP32 seems to work fine, but for my models it roughly doubles inference time, so it's faster to just use PyTorch directly. If I run both of these models in PyTorch on MPS, inference is ~20-30% slower than the converted CoreML model. For now I'm fine with this tradeoff, as CoreML does not look production-ready for everyone (if you're lucky, your model is not affected), and MLModelComparator is exactly the proof of this.
My models don't seem to support the ANE (Apple Neural Engine): ct.convert() doesn't even finish and freezes indefinitely, so I can't comment on the NPU.
@hongkun-Shao I am mostly using a Mac Mini M1 (macOS 26.1) and I get the same numerical error in fp16 regardless of the compute device (CPU/GPU/NPU) :confused:
I saw this same behavior. After bisecting the model I found a few things that caused it, but the one I couldn't fix is referenced here: https://github.com/apple/coremltools/issues/2621
I welcome the discussion here, but this isn't something that can be fixed in the coremltools repository. This is a Core ML framework issue. Please submit a bug report here: https://developer.apple.com/bug-reporting/
Hi @TobyRoseman, have you noticed this problem as well? Do you get the same result if you run any of the sample code posted here?
To expand on @johan-sightic's nn.Linear example, here's that same fp16 model converted to an mlpackage and run in a basic host app on macOS and iOS.
macOS 26.1 / MBP M2 Pro
Precision looks very good on CPU_AND_GPU and ALL, a bit lossy on CPU_AND_NE, and extremely lossy on CPU_ONLY.
iOS 26.1 / iPhone 17 Pro
Here the results mirror the CPU_ONLY results on macOS. I suspect this doesn't mean that all compute units show the same precision loss on iOS, but rather that Core ML chose to schedule the work on the CPU only for some reason?
Xcode project: CoreMLDebug.zip