High Numerical Errors in Mish Activation with FLOAT16 Precision on Neural Engine
🐞Describing the bug
The built-in Mish activation function in coremltools introduces significant numerical errors in Core ML models when using 16-bit floating point precision (FLOAT16) on configurations with ComputeUnit=CPU_AND_NE. Specifically, converting models that utilize the Mish activation results in substantial discrepancies in output predictions compared to the original model, leading to high error rates across various metrics.
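For reference, Mish is defined as Mish(x) = x * tanh(softplus(x)). A minimal PyTorch equivalent of the built-in activation, which is what the converted models are compared against:

import torch
import torch.nn.functional as F

def mish_reference(x: torch.Tensor) -> torch.Tensor:
    # Mish(x) = x * tanh(softplus(x)), matching torch.nn.Mish in float32
    return x * torch.tanh(F.softplus(x))

x = torch.linspace(-5.0, 5.0, 11)
assert torch.allclose(mish_reference(x), torch.nn.Mish()(x))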
Stack Trace
N/A
To Reproduce
Follow the steps below to reproduce the high numerical errors using the built-in Mish activation function:

1. Clone the KataGo repository:

       git clone --branch v1.15.3-coreml1 https://github.com/ChinChangYang/KataGo.git KataGo-v1.15.3-coreml1
       cd KataGo-v1.15.3-coreml1/python

2. Download a KataGo model in RAW checkpoint format:

       wget https://media.katagotraining.org/uploaded/networks/zips/kata1/kata1-b18c384nbt-s9996604416-d4316597426.zip
       unzip kata1-b18c384nbt-s9996604416-d4316597426.zip
       ln -s kata1-b18c384nbt-s9996604416-d4316597426/model.ckpt model.ckpt

3. Install the required Python modules:

       pip install torch coremltools matplotlib

4. Evaluate the high error of the built-in Mish implementation:

       wget https://gist.githubusercontent.com/ChinChangYang/529ccdffb90b60d307550b067f2fbab8/raw/abc3050cfad77e1ec87c92f61bd4b8c1b4f6cc28/testcoremlerror_original.py
       python testcoremlerror_original.py

   Expected Output:

       Mean Absolute Errors Across Samples:
         var_2572:  FLOAT16: 1.042287  FLOAT32: 0.000095
         linear_9:  FLOAT16: 3.587491  FLOAT32: 0.000245
         linear_10: FLOAT16: 2.812497  FLOAT32: 0.000182
         linear_11: FLOAT16: 2.498940  FLOAT32: 0.000269
         var_2631:  FLOAT16: 0.079012  FLOAT32: 0.000011

5. Evaluate the lower error of the alternative Mish implementation:

       wget https://gist.githubusercontent.com/ChinChangYang/b9d45f13a40ff738baa607a265a0b2c3/raw/8bf3ae8e66946451be7dbd0d6debdae9d8e82fcf/testcoremlerror_workaround.py
       python testcoremlerror_workaround.py

   Expected Output:

       Mean Absolute Errors Across Samples:
         var_2572:  FLOAT16: 0.008898  FLOAT32: 0.000395
         linear_9:  FLOAT16: 0.018509  FLOAT32: 0.000812
         linear_10: FLOAT16: 0.014011  FLOAT32: 0.000628
         linear_11: FLOAT16: 0.016918  FLOAT32: 0.000859
         var_2631:  FLOAT16: 0.001414  FLOAT32: 0.000036
System environment (please complete the following information):
- coremltools version: 8.0
- OS: macOS 15.0
- Any other relevant version information:
- PyTorch: 2.4.1
Additional context
The issue arises specifically when using ComputeUnit=CPU_AND_NE with Precision=FLOAT16. The built-in Mish activation function in coremltools leads to high numerical errors, as evidenced by metrics such as winrateError, leadError, and others showing discrepancies upwards of 25%. Switching to an alternative Mish implementation drastically reduces these errors to below 1%, albeit with a 32% increase in inference time due to the additional operators introduced.
This problem is isolated to 16-bit floating point precision on the Neural Engine (NE), as experiments with other compute units and precision settings (e.g., FLOAT32) do not exhibit the same high error rates. The significant reduction in error using the alternative Mish implementation suggests that the built-in Mish operator may have implementation issues when used in this specific configuration.
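A plausible contributing factor, offered here as an assumption rather than a confirmed root cause, is float16 overflow inside Softplus: exp(x) exceeds the float16 maximum of 65504 once x > ln(65504) ≈ 11.09, whereas the workaround's clamp at 10.39 keeps exp(x) ≈ 32,500 safely in range. A minimal NumPy sketch of the failure mode:

import numpy as np

np.seterr(over="ignore")  # the float16 overflow below is intentional

x = np.float16(12.0)
softplus = np.log1p(np.exp(x))  # exp(12) ~ 162755 overflows float16 to inf
print(softplus)                 # inf
print(x * np.tanh(softplus))    # NumPy recovers 12.0 since tanh(inf) = 1, but an
                                # accelerator's handling of the inf intermediate may differ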
This issue is based on a detailed analysis of numerical errors in Core ML models using the Mish activation function with 16-bit precision, as documented in the related blog post. Further investigation by the coremltools engineering team would be greatly appreciated.
I have written the alternative Mish implementation here:
def mish_torch_sigmoid(context, node):
    inputs = _get_inputs(context, node, expected=1)
    x = inputs[0]
    threshold = 10.39
    # Approximating conditional behavior using sigmoid function
    sigmoid_threshold = mb.sigmoid(x=mb.sub(x=x, y=threshold))
    # Approximate implementation of Softplus
    softplus_part = mb.softplus(x=mb.minimum(x=x, y=threshold))
    softplus = mb.add(
        x=mb.mul(x=x, y=sigmoid_threshold),
        y=mb.mul(x=softplus_part, y=mb.sub(x=1.0, y=sigmoid_threshold)),
    )
    # Mish(x) = x * tanh(Softplus(x))
    tanh_softplus = mb.tanh(x=softplus)
    res = mb.mul(x=x, y=tanh_softplus, name=node.name)
    context.add(res)
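As a quick sanity check (my addition, in float64 NumPy rather than MIL), the sigmoid-gated blend tracks exact Mish closely across a wide input range, because softplus(x) ≈ x well above the threshold and tanh is already saturated there:

import numpy as np

def mish_exact(x):
    # Reference Mish in float64: x * tanh(softplus(x))
    return x * np.tanh(np.log1p(np.exp(x)))

def mish_blended(x, t=10.39):
    gate = 1.0 / (1.0 + np.exp(-(x - t)))            # sigmoid gate around t
    sp_clamped = np.log1p(np.exp(np.minimum(x, t)))  # softplus with clamped input
    sp = x * gate + sp_clamped * (1.0 - gate)        # softplus(x) ~ x for x >> 0
    return x * np.tanh(sp)

x = np.linspace(-20.0, 20.0, 10001)
print(np.max(np.abs(mish_exact(x) - mish_blended(x))))  # on the order of 1e-5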
For security reasons, I am not able to download and run your network. Please create a minimal example to demonstrate the issue. Ideally some small amount of self contained code that I can just copy and paste.
I have created a minimal example to demonstrate the issue below. It is a small amount of self-contained code that you can simply copy and paste.
To Reproduce
Below are two scripts that reproduce this issue: one uses the built-in Mish activation, and the other uses the alternative Mish implementation.
Built-in Mish Activation
import torch
import torch.nn as nn
import coremltools as ct
import numpy as np
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    _TORCH_OPS_REGISTRY,
    register_torch_op,
)
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil import Builder as mb

# TO ENABLE THE WORKAROUND MISH FUNCTION, UNCOMMENT THE FOLLOWING LINES OF CODE
# # Remove the original mish function
# if "mish" in _TORCH_OPS_REGISTRY:
#     del _TORCH_OPS_REGISTRY["mish"]
#
# # Register the new mish function
# @register_torch_op
# def mish(context, node):
#     inputs = _get_inputs(context, node, expected=1)
#     x = inputs[0]
#     threshold = 10.39
#     # Approximating conditional behavior using sigmoid function
#     sigmoid_threshold = mb.sigmoid(x=mb.sub(x=x, y=threshold))
#     # Approximate implementation of Softplus
#     softplus_part = mb.softplus(x=mb.minimum(x=x, y=threshold))
#     softplus = mb.add(
#         x=mb.mul(x=x, y=sigmoid_threshold),
#         y=mb.mul(x=softplus_part, y=mb.sub(x=1.0, y=sigmoid_threshold)),
#     )
#     # Mish(x) = x * tanh(Softplus(x))
#     tanh_softplus = mb.tanh(x=softplus)
#     res = mb.mul(x=x, y=tanh_softplus, name=node.name)
#     context.add(res)
# TO ENABLE THE WORKAROUND MISH FUNCTION, UNCOMMENT THE ABOVE LINES OF CODE


class MishModel(nn.Module):
    def __init__(self):
        super(MishModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
        self.act = nn.Mish()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28 * 16, 10)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return x


# Export to Core ML
def export_to_coreml(model, target_path, compute_units):
    dummy_input = torch.randn(1, 1, 28, 28)
    with torch.no_grad():
        model.eval()
        traced_model = torch.jit.trace(model, dummy_input)
    inputs = [ct.TensorType(shape=tuple(dummy_input.shape))]
    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS15,
        compute_units=compute_units,
    )
    mlmodel.save(target_path)
    return mlmodel


def generate_random_inputs(model, batch_size=1):
    inputs = torch.randint(
        low=0,
        high=40,
        size=(batch_size, 1, 28, 28),
        dtype=torch.float32,
    )
    return inputs


def get_coreml_outputs(mlmodel, inputs):
    try:
        predictions = mlmodel.predict(inputs)
        return predictions
    except Exception as e:
        print(f"Error during CoreML model prediction: {e}")
        raise


def flatten_outputs(outputs):
    flattened = []
    if isinstance(outputs, torch.Tensor):
        flattened.append(outputs)
    elif isinstance(outputs, (tuple, list)):
        for item in outputs:
            flattened.extend(flatten_outputs(item))
    else:
        raise TypeError(f"Unsupported output type: {type(outputs)}")
    return flattened


def compute_error(torch_outputs, coreml_outputs, output_names):
    errors = {}
    flattened_torch_outputs = flatten_outputs(torch_outputs)
    for idx, torch_output in enumerate(flattened_torch_outputs):
        torch_np = torch_output.cpu().numpy()
        coreml_key = output_names[idx]
        coreml_np = coreml_outputs[coreml_key]
        error = np.mean(np.abs(torch_np - coreml_np))
        errors[coreml_key] = error
    return errors


# Main function
def main():
    model = MishModel()
    coreml_model_gpu = export_to_coreml(
        model,
        "model_fp32.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
    )
    coreml_model_ne = export_to_coreml(
        model,
        "model_fp16.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_NE,
    )
    spec = coreml_model_ne._spec
    input_names = [input.name for input in spec.description.input]
    output_names = [output.name for output in spec.description.output]
    input_name = input_names[0]
    num_samples = 30
    test_inputs = generate_random_inputs(model, batch_size=num_samples)
    errors_ne = {}
    errors_gpu = {}
    for name in output_names:
        errors_ne[name] = []
        errors_gpu[name] = []
    # Iterate over each sample
    for i in range(num_samples):
        # Prepare single sample inputs
        single_input = test_inputs[i].unsqueeze(0)  # Shape: (1, C, H, W)
        # Prepare input dictionary for CoreML prediction
        input_dict = {
            input_name: single_input.numpy(),
        }
        # Compute PyTorch outputs
        with torch.no_grad():
            torch_output = model(single_input)
        # Ensure torch_output is a tuple
        if not isinstance(torch_output, tuple):
            torch_output = (torch_output,)
        # Compute CoreML outputs
        coreml_output_ne = get_coreml_outputs(coreml_model_ne, input_dict)
        coreml_output_gpu = get_coreml_outputs(coreml_model_gpu, input_dict)
        # Compute errors for each output
        error_current_ne = compute_error(torch_output, coreml_output_ne, output_names)
        error_current_gpu = compute_error(torch_output, coreml_output_gpu, output_names)
        # Accumulate errors
        for name in output_names:
            if error_current_ne.get(name) is not None:
                errors_ne[name].append(error_current_ne[name])
            if error_current_gpu.get(name) is not None:
                errors_gpu[name].append(error_current_gpu[name])
    # Compute mean errors across all samples
    mean_errors_ne = {name: np.mean(errors_ne[name]) for name in output_names}
    mean_errors_gpu = {name: np.mean(errors_gpu[name]) for name in output_names}
    # Display mean errors
    print("\nMean Absolute Errors Across Samples:")
    for output_name in output_names:
        ne_error = mean_errors_ne[output_name]
        gpu_error = mean_errors_gpu[output_name]
        print(f"  {output_name}:")
        print(f"    NE: {ne_error:.6f}")
        print(f"    GPU: {gpu_error:.6f}")


if __name__ == "__main__":
    main()
Expected Output
Converting PyTorch Frontend ==> MIL Ops: 90% | 9/10 [00:00<00:00, 2181.88 ops/s]
Running MIL frontend_pytorch pipeline: 100% | 5/5 [00:00<00:00, 5500.01 passes/s]
Running MIL default pipeline: 100% | 88/88 [00:00<00:00, 3854.05 passes/s]
Running MIL backend_mlprogram pipeline: 100% | 12/12 [00:00<00:00, 9894.17 passes/s]
Converting PyTorch Frontend ==> MIL Ops: 90% | 9/10 [00:00<00:00, 5401.94 ops/s]
Running MIL frontend_pytorch pipeline: 100% | 5/5 [00:00<00:00, 7162.40 passes/s]
Running MIL default pipeline: 100% | 88/88 [00:00<00:00, 4313.92 passes/s]
Running MIL backend_mlprogram pipeline: 100% | 12/12 [00:00<00:00, 9607.11 passes/s]
Mean Absolute Errors Across Samples:
linear_0:
NE: 3.951386
GPU: 0.001291
Alternative Mish Implementation
import torch
import torch.nn as nn
import coremltools as ct
import numpy as np
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    _TORCH_OPS_REGISTRY,
    register_torch_op,
)
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil import Builder as mb

# THE WORKAROUND MISH FUNCTION IS ENABLED BELOW
# Remove the original mish function
if "mish" in _TORCH_OPS_REGISTRY:
    del _TORCH_OPS_REGISTRY["mish"]

# Register the new mish function
@register_torch_op
def mish(context, node):
    inputs = _get_inputs(context, node, expected=1)
    x = inputs[0]
    threshold = 10.39
    # Approximating conditional behavior using sigmoid function
    sigmoid_threshold = mb.sigmoid(x=mb.sub(x=x, y=threshold))
    # Approximate implementation of Softplus
    softplus_part = mb.softplus(x=mb.minimum(x=x, y=threshold))
    softplus = mb.add(
        x=mb.mul(x=x, y=sigmoid_threshold),
        y=mb.mul(x=softplus_part, y=mb.sub(x=1.0, y=sigmoid_threshold)),
    )
    # Mish(x) = x * tanh(Softplus(x))
    tanh_softplus = mb.tanh(x=softplus)
    res = mb.mul(x=x, y=tanh_softplus, name=node.name)
    context.add(res)
# END OF THE WORKAROUND MISH FUNCTION


class MishModel(nn.Module):
    def __init__(self):
        super(MishModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
        self.act = nn.Mish()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28 * 16, 10)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return x


# Export to Core ML
def export_to_coreml(model, target_path, compute_units):
    dummy_input = torch.randn(1, 1, 28, 28)
    with torch.no_grad():
        model.eval()
        traced_model = torch.jit.trace(model, dummy_input)
    inputs = [ct.TensorType(shape=tuple(dummy_input.shape))]
    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS15,
        compute_units=compute_units,
    )
    mlmodel.save(target_path)
    return mlmodel


def generate_random_inputs(model, batch_size=1):
    inputs = torch.randint(
        low=0,
        high=40,
        size=(batch_size, 1, 28, 28),
        dtype=torch.float32,
    )
    return inputs


def get_coreml_outputs(mlmodel, inputs):
    try:
        predictions = mlmodel.predict(inputs)
        return predictions
    except Exception as e:
        print(f"Error during CoreML model prediction: {e}")
        raise


def flatten_outputs(outputs):
    flattened = []
    if isinstance(outputs, torch.Tensor):
        flattened.append(outputs)
    elif isinstance(outputs, (tuple, list)):
        for item in outputs:
            flattened.extend(flatten_outputs(item))
    else:
        raise TypeError(f"Unsupported output type: {type(outputs)}")
    return flattened


def compute_error(torch_outputs, coreml_outputs, output_names):
    errors = {}
    flattened_torch_outputs = flatten_outputs(torch_outputs)
    for idx, torch_output in enumerate(flattened_torch_outputs):
        torch_np = torch_output.cpu().numpy()
        coreml_key = output_names[idx]
        coreml_np = coreml_outputs[coreml_key]
        error = np.mean(np.abs(torch_np - coreml_np))
        errors[coreml_key] = error
    return errors


# Main function
def main():
    model = MishModel()
    coreml_model_gpu = export_to_coreml(
        model,
        "model_fp32.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
    )
    coreml_model_ne = export_to_coreml(
        model,
        "model_fp16.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_NE,
    )
    spec = coreml_model_ne._spec
    input_names = [input.name for input in spec.description.input]
    output_names = [output.name for output in spec.description.output]
    input_name = input_names[0]
    num_samples = 30
    test_inputs = generate_random_inputs(model, batch_size=num_samples)
    errors_ne = {}
    errors_gpu = {}
    for name in output_names:
        errors_ne[name] = []
        errors_gpu[name] = []
    # Iterate over each sample
    for i in range(num_samples):
        # Prepare single sample inputs
        single_input = test_inputs[i].unsqueeze(0)  # Shape: (1, C, H, W)
        # Prepare input dictionary for CoreML prediction
        input_dict = {
            input_name: single_input.numpy(),
        }
        # Compute PyTorch outputs
        with torch.no_grad():
            torch_output = model(single_input)
        # Ensure torch_output is a tuple
        if not isinstance(torch_output, tuple):
            torch_output = (torch_output,)
        # Compute CoreML outputs
        coreml_output_ne = get_coreml_outputs(coreml_model_ne, input_dict)
        coreml_output_gpu = get_coreml_outputs(coreml_model_gpu, input_dict)
        # Compute errors for each output
        error_current_ne = compute_error(torch_output, coreml_output_ne, output_names)
        error_current_gpu = compute_error(torch_output, coreml_output_gpu, output_names)
        # Accumulate errors
        for name in output_names:
            if error_current_ne.get(name) is not None:
                errors_ne[name].append(error_current_ne[name])
            if error_current_gpu.get(name) is not None:
                errors_gpu[name].append(error_current_gpu[name])
    # Compute mean errors across all samples
    mean_errors_ne = {name: np.mean(errors_ne[name]) for name in output_names}
    mean_errors_gpu = {name: np.mean(errors_gpu[name]) for name in output_names}
    # Display mean errors
    print("\nMean Absolute Errors Across Samples:")
    for output_name in output_names:
        ne_error = mean_errors_ne[output_name]
        gpu_error = mean_errors_gpu[output_name]
        print(f"  {output_name}:")
        print(f"    NE: {ne_error:.6f}")
        print(f"    GPU: {gpu_error:.6f}")


if __name__ == "__main__":
    main()
Expected Output
Converting PyTorch Frontend ==> MIL Ops: 90% | 9/10 [00:00<00:00, 1975.03 ops/s]
Running MIL frontend_pytorch pipeline: 100% | 5/5 [00:00<00:00, 4351.84 passes/s]
Running MIL default pipeline: 100% | 88/88 [00:00<00:00, 2862.23 passes/s]
Running MIL backend_mlprogram pipeline: 100% | 12/12 [00:00<00:00, 6956.69 passes/s]
Converting PyTorch Frontend ==> MIL Ops: 90% | 9/10 [00:00<00:00, 4345.93 ops/s]
Running MIL frontend_pytorch pipeline: 100% | 5/5 [00:00<00:00, 5482.75 passes/s]
Running MIL default pipeline: 100% | 88/88 [00:00<00:00, 3141.61 passes/s]
Running MIL backend_mlprogram pipeline: 100% | 12/12 [00:00<00:00, 7155.48 passes/s]
Mean Absolute Errors Across Samples:
linear_0:
NE: 0.001768
GPU: 0.001291