CoreML fails at runtime with int32 torch.mm
🐞Describing the bug
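Converting a model that uses `torch.mm` with int32 operands succeeds, but the resulting Core ML model cannot run: `predict()` raises a `RuntimeError` because the model execution plan fails to build (error code -14).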
Stack Trace
/opt/miniconda3/envs/op-et/lib/python3.10/site-packages/coremltools/models/model.py:560: RuntimeWarning: You will not be able to run predict() on this Core ML model. Underlying exception message was: {
NSLocalizedDescription = "Failed to build the model execution plan using a model architecture file '/private/var/folders/lw/phxpy6k10ll388xs18hyq1cr0000gn/T/tmpcb8uw0e1.mlmodelc/model.mil' with error code: -14.";
}
_warnings.warn(
Traceback (most recent call last):
File "/Users/scroy/Desktop/executorch/test.py", line 159, in <module>
out = mlmodel.predict(predict_inputs)
File "/opt/miniconda3/envs/op-et/lib/python3.10/site-packages/coremltools/models/model.py", line 804, in predict
raise self._framework_error
File "/opt/miniconda3/envs/op-et/lib/python3.10/site-packages/coremltools/models/model.py", line 549, in _get_proxy_and_spec
_MLModelProxy(
RuntimeError: {
NSLocalizedDescription = "Failed to build the model execution plan using a model architecture file '/private/var/folders/lw/phxpy6k10ll388xs18hyq1cr0000gn/T/tmpcb8uw0e1.mlmodelc/model.mil' with error code: -14.";
To Reproduce
"""Reproduce: an exported int32 ``torch.mm`` converts to Core ML but fails at predict time."""
import coremltools as ct
import numpy as np
import torch


class Model(torch.nn.Module):
    """Multiply the input by a fixed int32 (8, 8) weight with ``torch.mm``."""

    def __init__(self):
        super().__init__()
        # int32 weight with values drawn from [0, 100).
        self.weight = torch.randint(0, 100, (8, 8), dtype=torch.int32)

    def forward(self, x):
        return torch.mm(x, self.weight)


model = Model()
# torch.randn(...).to(torch.int32) truncates toward zero and yields mostly
# tiny values (about -2..2); draw integers directly so the int32 matmul is
# actually exercised.
inputs = (torch.randint(-100, 100, (8, 8), dtype=torch.int32),)
eager_outputs = model(*inputs)

ep = torch.export.export(model, inputs)
print(ep)
ep = ep.run_decompositions({})

mlmodel = ct.convert(ep)
coreml_inputs = mlmodel.get_spec().description.input
predict_inputs = {
    str(ct_in.name): pt_in.detach().cpu().numpy().astype(np.int32)
    for ct_in, pt_in in zip(coreml_inputs, inputs)
}
# With coremltools 8.3 this raises RuntimeError (execution plan error code -14).
out = mlmodel.predict(predict_inputs)
print("Eager", eager_outputs)
print("CoremL", out)
System environment:
- coremltools version: 8.3
- OS: macOS 15
def _construct_matmul(x: Var, y: Var, name: Optional[str] = None) -> Var:
    """Lower a torch matrix multiply ``x @ y`` to a MIL op.

    When ``y`` is a rank-2 constant (or derived only from constants) and
    ``x`` has rank <= 3, the product is emitted as ``mb.linear`` with the
    transposed weight. That fast path is restricted to floating-point
    operands: with int32 operands ``mb.linear`` produces wrong results on the
    neuralnetwork backend and a model that fails to load on the mlprogram
    backend. Every other case promotes both operands to a common dtype and
    emits ``mb.matmul``.

    Args:
        x: Left-hand operand.
        y: Right-hand operand.
        name: Optional name for the emitted op.

    Returns:
        The output Var of the emitted ``linear`` or ``matmul`` op.
    """
    from coremltools.converters.mil.mil import types

    use_linear = (
        len(y.shape) == 2
        and len(x.shape) <= 3
        and (_is_const(y) or y.is_descendant_of_const)
        # Integer (and mixed int/float) operands must not take the linear path.
        and types.is_float(x.dtype)
        and types.is_float(y.dtype)
    )
    if use_linear:
        # linear computes x @ weight.T, so pass y transposed to obtain x @ y.
        transposed_weight = mb.transpose(x=y, perm=(1, 0))
        return mb.linear(x=x, weight=transposed_weight, name=name)

    x, y = promote_input_dtypes([x, y])
    return mb.matmul(x=x, y=y, name=name)
@register_torch_op(torch_alias=["bmm", "mm"])
def matmul(context, node):
    """Convert torch ``matmul`` (also registered for ``bmm`` and ``mm``).

    Reads exactly two inputs from ``node`` and adds the Var produced by
    ``_construct_matmul`` to ``context`` under the node's name.
    """
    lhs, rhs = _get_inputs(context, node, expected=2)
    context.add(_construct_matmul(lhs, rhs, node.name))
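The root cause is a bug in `mb.linear` with int32 inputs. The standalone reproduction below gives an incorrect result on the `neuralnetwork` backend and errors on the `mlprogram` backend.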
Environment: macOS 15.6, coremltools built from source.
"""Standalone repro: ``mb.linear`` with int32 operands versus ``torch.mm``."""
import torch
import coremltools as ct
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.mil import types
import numpy as np

lhs = torch.randint(0, 100, (8, 8), dtype=torch.int32)
rhs = torch.randint(0, 100, (8, 8), dtype=torch.int32)
expected = torch.mm(lhs, rhs)


@mb.program(input_specs=[mb.TensorSpec(shape=lhs.shape, dtype=types.int32)])
def mm(x):
    # linear computes x @ weight.T, so hand it rhs transposed to get x @ rhs.
    weight = rhs.transpose(0, 1).numpy()
    return mb.linear(x=x, weight=mb.const(val=weight), name="y")


def _predict(backend):
    """Convert ``mm`` for ``backend`` and return its "y" output for ``lhs``."""
    converted = ct.convert(mm, convert_to=backend)
    return converted.predict({"x": lhs.numpy()})["y"]


print(np.array_equal(expected.numpy(), _predict("neuralnetwork").astype(np.int32)))  # False
print(np.array_equal(expected.numpy(), _predict("mlprogram").astype(np.int32)))  # ERROR
@M-Quadra - could you please put up a PR with your fix and a unit test?
@TobyRoseman
Using `mb.matmul` may avoid this RuntimeError, but other unit tests fail with torch 2.8.0, so CI won't be green. _(ˊཀˋ」∠)_
For example:
pytest --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestIndexPut::test_index_put_negative_indices_case_2
pytest --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestIndexPut::test_index_put_updates_bool
I don't understand. What does this have to do with torch 2.8.0?
For this issue, use `mb.matmul` instead:
@register_torch_op(torch_alias=["bmm", "mm"])
def matmul(context, node):
    """Convert torch ``matmul``/``bmm``/``mm`` directly to ``mb.matmul``.

    This variant never takes the ``mb.linear`` path: it reads two inputs,
    promotes them to a common dtype, and emits a single ``matmul`` op named
    after the node.
    """
    lhs, rhs = _get_inputs(context, node, expected=2)
    lhs, rhs = promote_input_dtypes([lhs, rhs])
    context.add(mb.matmul(x=lhs, y=rhs, name=node.name))
With this implementation, the following new unit test passes on my local machine:
@pytest.mark.parametrize(
    "compute_unit, backend, frontend, shape, dtype",
    itertools.product(
        compute_units,
        backends,
        frontends,
        [(1, 1, 1), (1, 3, 5), (3, 3, 3), (5, 3, 1), (5, 5, 5)],
        [torch.float16, torch.float32, torch.int32],
    ),
)
def test_matmul(self, compute_unit, backend, frontend, shape, dtype):
    """Compare ``torch.mm`` with the converted model for an (m, k) x (k, n) product."""
    m, k, n = shape

    def _random(dims):
        # int32 operands come from [-9, 9); float operands from a standard normal.
        if dtype == torch.int32:
            return torch.randint(-9, 9, dims, dtype=dtype)
        return torch.randn(dims, dtype=dtype)

    lhs = _random((m, k))
    rhs = _random((k, n))
    target = ct.target.iOS16 if dtype == torch.float16 else None
    self.run_compare_torch(
        (lhs, rhs),
        ModuleWrapper(function=torch.mm),
        compute_unit=compute_unit,
        backend=backend,
        frontend=frontend,
        input_as_shape=False,
        minimum_deployment_target=target,
    )
For the full unit-test suite, see how the test environment installs torch:
https://github.com/apple/coremltools/blob/be33582e2fd62885f15d22338cac9f777ad8119f/scripts/env_create.sh#L98-L99
https://github.com/apple/coremltools/blob/be33582e2fd62885f15d22338cac9f777ad8119f/reqs/test.pip#L2
https://github.com/apple/coremltools/blob/be33582e2fd62885f15d22338cac9f777ad8119f/reqs/pytorch.pip#L8
The current test environment installs torch 2.8.0, which causes existing unit tests to fail (verified on my local device).
pytest --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
It seems the torch version behind these CI failures is a separate, prerequisite issue that needs to be resolved first.