Converting Faster-RCNN from PyTorch to CoreML

Open gizzleon opened this issue 9 months ago • 5 comments

🐞Describing the bug

Hi, I am converting a PyTorch Faster R-CNN model to CoreML and ran into a data type mismatch, which may be related to https://github.com/apple/coremltools/issues/2440

The model I'm converting is torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn_v2.

The first issue was the unsupported torchvision::roi_align operator. With the implementation from this PR, I was able to convert a single RoIAlign layer.

However, when converting the whole Faster R-CNN model, the second input variable, rois, has an unexpected shape of (0, 1) and dtype int32, whereas it is supposed to be an (N, 5) float tensor.
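To make the mismatch concrete, here is a minimal NumPy sketch that checks an RoI array against the expected (N, 5) float layout. The helper name check_rois is hypothetical and is not part of coremltools or torchvision; it only illustrates the two properties that differ between the single-layer case and the full-model case.

```python
import numpy as np


def check_rois(rois):
    """Return a list of problems found in an RoI array (empty list means OK)."""
    rois = np.asarray(rois)
    problems = []
    if rois.ndim != 2 or rois.shape[1] != 5:
        problems.append(f"expected shape (N, 5), got {rois.shape}")
    if not np.issubdtype(rois.dtype, np.floating):
        problems.append(f"expected a float dtype, got {rois.dtype}")
    return problems


# Same values as rois_stacked in the reproduction script:
# each row is (batch_index, x1, y1, x2, y2).
good = np.array([[0, 0, 0, 10, 10], [0, 5, 5, 20, 20]], dtype=np.float32)

# Shape and dtype reported for rois during the full-model conversion.
bad = np.zeros((0, 1), dtype=np.int32)

print(check_rois(good))  # no problems reported
print(check_rois(bad))   # both the shape and the dtype are flagged
```

A check like this could be placed inside the custom roi_align converter to fail early with a clearer message, but it does not explain why the traced graph hands over an empty int32 tensor in the first place.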

Stack Trace

ERROR - converting 'torchvision::roi_align' op (located at: 'network/roi_heads/box_roi_pool/result_idx_in_level.1'):

Converting PyTorch Frontend ==> MIL Ops:  81%|████████▏ | 1374/1686 [00:00<00:00, 6381.85 ops/s]
Traceback (most recent call last):
  File "./bug_report.py", line 104, in <module>
    convert_faster_rcnn_model()
  File "./bug_report.py", line 101, in convert_faster_rcnn_model
    ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])
  File "./venv/lib/python3.12/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
              ^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
                         ^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
    return _perform_torch_convert(converter, debug)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
    prog = converter.convert()
           ^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1387, in convert
    convert_nodes(self.context, self.graph, early_exit=not has_states)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 116, in convert_nodes
    raise e     # re-raise exception
    ^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 111, in convert_nodes
    convert_single_node(context, node)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 175, in convert_single_node
    add_op(context, node)
  File "./bug_report.py", line 46, in roi_align
    x = mb.crop_resize(
        ^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/builder.py", line 217, in _add_op
    new_op = op_cls(**kwargs)
             ^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 195, in __init__
    self._validate_and_set_inputs(input_kv)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 511, in _validate_and_set_inputs
    self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/input_type.py", line 138, in validate_inputs
    raise ValueError(msg)
ValueError: In op, of type crop_resize, named crop_resize_0, the named input `roi` must have the same data type as the named input `x`. However, roi has dtype int32 whereas x has dtype fp32.

To Reproduce

import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    register_torch_op,
)
from coremltools.converters.mil.mil import Builder as mb
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2

from torchvision.ops.roi_align import RoIAlign


@register_torch_op(torch_alias=["torchvision::roi_align"])
def roi_align(context, node):
    inputs = _get_inputs(context, node)

    x = context[node.inputs[0]]
    input_shape = x.shape  # (B, C, h_in, w_in)
    if len(input_shape) != 4:
        raise ValueError(
            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
        )

    const_box_info = True
    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
        const_box_info = False

    # Note: input 2 of torchvision::roi_align is spatial_scale, despite the variable name
    extrapolation_value = context[node.inputs[2]].val

    # CoreML index information along with boxes
    if const_box_info:
        boxes = context[node.inputs[1]].val
        # CoreML expects boxes/ROI in
        # [N, 1, 5, 1, 1] format
        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
    else:
        boxes = inputs[1]
        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
    # Get Height and Width of crop
    h_out = inputs[3]
    w_out = inputs[4]

    # Torch input format: [B, C, h_in, w_in]
    # CoreML input format: [B, C, h_in, w_in]

    # Crop Resize
    x = mb.crop_resize(
        x=x,
        roi=boxes,
        target_height=h_out.val,
        target_width=w_out.val,
        normalized_coordinates=True,
        spatial_scale=extrapolation_value,
        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
        sampling_mode="OFFSET_CORNERS",
    )

    # CoreML output format: [N, 1, C, h_out, w_out]
    # Torch output format: [N, C, h_out, w_out]
    x = mb.squeeze(x=x, axes=[1])

    context.add(x, torch_name=node.outputs[0])


def convert_roi_align_layer():
    roi_align_layer = RoIAlign(
        output_size=(7, 7), spatial_scale=1.0, sampling_ratio=1, aligned=False
    )

    input_tensor = torch.randn((1, 3, 400, 800))
    rois_stacked = torch.FloatTensor([[0, 0, 0, 10, 10], [0, 5, 5, 20, 20]])

    roi_align_layer.eval()

    traced_model = torch.jit.trace(roi_align_layer, (input_tensor, rois_stacked))

    ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(name="Input", shape=input_tensor.shape),
            ct.TensorType(name="Rois", shape=rois_stacked.shape),
        ],
    )


def convert_faster_rcnn_model():
    model = fasterrcnn_resnet50_fpn_v2(pretrained=False)

    class ModelWrapper(torch.nn.Module):
        def __init__(self, network: torch.nn.Module):
            super().__init__()
            self.network = network

        def forward(self, x):
            output = self.network(x)[0]
            return output["boxes"], output["labels"], output["scores"]

    wrapped_model = ModelWrapper(model)

    input_ = torch.randn((1, 3, 400, 800))
    wrapped_model.eval()

    traced_model = torch.jit.trace(wrapped_model, input_)

    ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])


convert_roi_align_layer()
convert_faster_rcnn_model()

System environment:

  • coremltools version: 8.1
  • OS (e.g. MacOS version or Linux type): MacOS 15.3.2
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • torch==2.5.1
    • torchvision==0.20.1
    • numpy==1.26.4
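For anyone comparing environments, the versions listed above can be collected with a short standard-library sketch. The helper name installed_versions is hypothetical; only importlib.metadata is used, so nothing is imported from the packages themselves.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(
    ["coremltools", "torch", "torchvision", "numpy"]
).items():
    print(f"{name}=={ver}" if ver else f"{name}: not installed")
```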

gizzleon avatar Apr 04 '25 14:04 gizzleon

@gizzleon Just out of curiosity: coremltools, torch, and torchvision had newer releases in January. Is this report older than 2 weeks? Could you try again with the current dependencies?

Also, when I install coremltools, numpy 2.2 is used. Wasn't that the case for you?

reneleonhardt avatar Apr 17 '25 08:04 reneleonhardt

@reneleonhardt The issue persists with the newer versions: coremltools 8.2, torch 2.6.0 and torchvision 0.21.0.

I am not able to use numpy 2.x as coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py uses a copy operation deprecated in 2.x.

gizzleon avatar Apr 17 '25 09:04 gizzleon

I can't find any deprecations in 2.0, 2.1 or 2.2 regarding copy or these two function calls. https://numpy.org/doc/stable/release/2.0.0-notes.html#deprecations

Release notes say numpy 2 is supported: https://github.com/apple/coremltools/releases/tag/8.0

If you have time, maybe you can open another issue for your environment 🙂

reneleonhardt avatar Apr 17 '25 11:04 reneleonhardt

It is a behavior change on the copy keyword rather than a deprecation. Sorry for the confusion.

The log I got with numpy 2.2.4:

  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 112, in convert_nodes
    convert_single_node(context, node)
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 173, in convert_single_node
    add_op(context, node)
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 6992, in reciprocal
    context.add(mb.inverse(x=inputs[0], name=node.name))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/builder.py", line 237, in _add_op
    new_op.type_value_inference()
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 265, in type_value_inference
    output_vals = self._auto_val(output_types)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 382, in _auto_val
    vals = self.value_inference()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 111, in wrapper
    return func(self)
           ^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py", line 449, in value_inference
    return np.array(np.reciprocal(self.x.val + self.epsilon.val), copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Unable to avoid copy while creating an array as requested.
If using `np.array(obj, copy=False)` replace it with `np.asarray(obj)` to allow a copy when needed (no behavior change in NumPy 1.x).
For more details, see https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword.

Unfortunately, I couldn't isolate this into a smaller sample for reproduction. The issue goes away when I convert a minimal network that only uses torch.reciprocal.

gizzleon avatar Apr 17 '25 12:04 gizzleon

Interesting, now I can see what you mean, thank you! This was the only copy=False, so I created https://github.com/apple/coremltools/pull/2488 to migrate to np.asarray(). Maybe it will be merged 🤞

reneleonhardt avatar Apr 17 '25 13:04 reneleonhardt
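
The copy keyword change discussed above can be seen in isolation with a small sketch. It uses a plain Python list as input, which always needs a copy to become an ndarray; whether the value reaching elementwise_unary.py in the failing conversion has the same property is an assumption, not something confirmed in this thread. The snippet runs on both NumPy 1.x and 2.x and only shows the general pattern that the NumPy error message recommends.

```python
import numpy as np

values = [2.0, 4.0]  # a Python list must be copied to become an ndarray

# Old spelling: NumPy 1.x copies silently, while NumPy 2.x raises ValueError
# because copy=False now means "never copy".
try:
    np.array(values, copy=False)
    legacy_raised = False
except ValueError:
    legacy_raised = True

# Portable spelling: copies only when needed, on both major versions.
result = np.asarray(np.reciprocal(values))

print("np.array(..., copy=False) raised:", legacy_raised)
print(result)
```

When the input is already an ndarray of the right dtype, no copy is needed and the old spelling does not raise on either version, which is one possible reason a minimal network does not trigger the error.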