
ct.models.pipeline.Pipeline gives different predictions than its models run consecutively

Status: Open. mateuszlugowski opened this issue 10 months ago • 5 comments

🐞Describing the bug

The bug became apparent when we tried to export a detection model to Core ML, using ct.models.pipeline.Pipeline to combine the detector (which outputs raw xywh boxes and confidences) with an NMS model. The resulting pipeline model gives incorrect predictions.

However, when we take the pipeline's models separately and run the NMS model on the detector's predictions, the results are correct.

To Reproduce

import tensorflow as tf
import coremltools as ct
from PIL import Image

# Construct a dummy YOLO-like model.
# When we pass a completely black (all zeros) image through it,
# all coordinates and confidences equal 0.5: the pooled input is 0,
# the Dense biases are zero-initialized, and sigmoid(0.) = 0.5.

n_classes = 34
n_anchors = 4032

inputs = tf.keras.Input(shape=(256, 256, 3), name='image')

confidences = tf.keras.layers.GlobalAveragePooling2D()(inputs)
confidences = tf.keras.layers.Dense(n_classes * n_anchors, activation='sigmoid')(confidences)
confidences = tf.keras.layers.Reshape((n_anchors, n_classes))(confidences)
confidences = tf.squeeze(confidences, axis=0)

coordinates = tf.keras.layers.GlobalAveragePooling2D()(inputs)
coordinates = tf.keras.layers.Dense(4 * n_anchors, activation='sigmoid')(coordinates)
coordinates = tf.keras.layers.Reshape((n_anchors, 4))(coordinates)
coordinates = tf.squeeze(coordinates, axis=0)

dummy_detector = tf.keras.Model(inputs=inputs, outputs=[confidences, coordinates])

input_shape = [1] + dummy_detector.input.shape[1:].as_list()

detector_model = ct.convert(
    dummy_detector,
    inputs=[ct.ImageType("image", shape=input_shape)],
    convert_to="mlprogram",
)


# Constructing NMS model

nms_spec = ct.proto.Model_pb2.Model()
nms_spec.specificationVersion = 5

for i in range(2):
    decoder_output = detector_model._spec.description.output[i].SerializeToString()

    nms_spec.description.input.add()
    nms_spec.description.input[i].ParseFromString(decoder_output)

    nms_spec.description.output.add()
    nms_spec.description.output[i].ParseFromString(decoder_output)

nms_spec.description.output[0].name = "confidence"
nms_spec.description.output[1].name = "coordinates"

output_sizes = [n_classes, 4]
for i in range(2):
    ma_type = nms_spec.description.output[i].type.multiArrayType
    ma_type.shapeRange.sizeRanges.add()
    ma_type.shapeRange.sizeRanges[0].lowerBound = 0
    ma_type.shapeRange.sizeRanges[0].upperBound = -1
    ma_type.shapeRange.sizeRanges.add()
    ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
    ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
    del ma_type.shape[:]

nms = nms_spec.nonMaximumSuppression
nms.confidenceInputFeatureName = detector_model._spec.description.output[0].name
nms.coordinatesInputFeatureName = detector_model._spec.description.output[1].name
nms.confidenceOutputFeatureName = "confidence"
nms.coordinatesOutputFeatureName = "coordinates"
nms.iouThresholdInputFeatureName = "iouThreshold"
nms.confidenceThresholdInputFeatureName = "confidenceThreshold"

nms.iouThreshold = 0.35
nms.confidenceThreshold = 0.1

nms.pickTop.perClass = True

nms_model = ct.models.MLModel(nms_spec)


# Constructing pipeline model

pipeline_model = ct.models.utils.make_pipeline(detector_model, nms_model)


# Predictions

img = Image.new(mode='RGB', size=(256, 256), color=0)
print("Pipeline predictions:")
# Expected: only one box and no zeros in the confidences (the NMS input contains none)
print(pipeline_model.predict({'image': img}))

detector_from_pipeline = ct.models.MLModel(
    pipeline_model._spec.pipeline.models[0],
    weights_dir=pipeline_model.weights_dir
)

nms_from_pipeline = ct.models.MLModel(
    pipeline_model._spec.pipeline.models[1]
)

print("Individual models consecutive prediction:")
# Correct behavior
print(nms_from_pipeline.predict(detector_from_pipeline.predict({'image': img})))
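To check the discrepancy programmatically rather than by eye, the two result dictionaries can be compared key by key. This is a minimal sketch, assuming both predictions are dicts of NumPy arrays as returned by MLModel.predict; the helper name compare_predictions is hypothetical and not part of coremltools.

```python
import numpy as np


def compare_predictions(expected, actual, atol=1e-6):
    """Return a list of human-readable differences between two prediction dicts.

    Hypothetical helper: an empty list means the dicts agree within atol.
    """
    problems = []
    # Keys that appear in only one of the two dicts.
    for key in sorted(set(expected) ^ set(actual)):
        problems.append(f"{key}: present in only one prediction")
    # For shared keys, check the shape first, then the values.
    for key in sorted(set(expected) & set(actual)):
        a = np.asarray(expected[key])
        b = np.asarray(actual[key])
        if a.shape != b.shape:
            problems.append(f"{key}: shape {a.shape} != {b.shape}")
        elif not np.allclose(a, b, atol=atol):
            diff = np.max(np.abs(a - b))
            problems.append(f"{key}: values differ (max abs diff {diff})")
    return problems
```

With the objects from the repro, passing the two-step result as expected and pipeline_model.predict({'image': img}) as actual should return an empty list if the pipeline behaves like the consecutive calls.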

Output (both pipeline predictions and individual models' predictions should be the same):

Pipeline predictions:

{'coordinates': array(
       [[0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5],
       [0.5, 0.5, 0.5, 0.5]],
       dtype=float32),
 'confidence': array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
       [0. , 0. , 0. , 0. , 0. , 0. , 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
       [0. , 0. , 0. , 0. , 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], dtype=float32)}


Individual models consecutive prediction:

{'coordinates': array([[0.5, 0.5, 0.5, 0.5]], dtype=float32), 
'confidence': array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]], dtype=float32)}
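The single box in the two-step result is what greedy NMS should produce here: every detector box is identical, so once the first box is kept, all others have IoU 1.0 with it and are suppressed. The following is a minimal pure-Python sketch of per-class greedy NMS to reason about the expected result. It is not Core ML's implementation; the center-based box format, the tie-breaking, the threshold boundary handling, and merging per-class results as a union of kept indices are all assumptions.

```python
def iou_xywh(a, b):
    """IoU of two boxes given as (center_x, center_y, width, height)."""
    ax0, ay0 = a[0] - a[2] / 2, a[1] - a[3] / 2
    ax1, ay1 = a[0] + a[2] / 2, a[1] + a[3] / 2
    bx0, by0 = b[0] - b[2] / 2, b[1] - b[3] / 2
    bx1, by1 = b[0] + b[2] / 2, b[1] + b[3] / 2
    inter_w = max(0.0, min(ax1, bx1) - max(ax0, bx0))
    inter_h = max(0.0, min(ay1, by1) - max(ay0, by0))
    inter = inter_w * inter_h
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0


def per_class_nms(boxes, scores, iou_threshold=0.35, confidence_threshold=0.1):
    """Greedy NMS run independently for each class (reference sketch only).

    boxes: list of (cx, cy, w, h); scores: per-box lists of class scores.
    Returns the sorted indices of boxes kept for at least one class.
    """
    n_classes = len(scores[0]) if scores else 0
    kept_any = set()
    for c in range(n_classes):
        # Candidates at or above the confidence threshold, highest score first.
        # sorted() is stable, so ties keep the lower index first.
        order = sorted(
            (i for i in range(len(boxes)) if scores[i][c] >= confidence_threshold),
            key=lambda i: -scores[i][c],
        )
        kept = []
        for i in order:
            # Keep a box only if it does not overlap too much with any kept box.
            if all(iou_xywh(boxes[i], boxes[j]) <= iou_threshold for j in kept):
                kept.append(i)
        kept_any.update(kept)
    return sorted(kept_any)
```

Called with 4032 copies of the box (0.5, 0.5, 0.5, 0.5) and 34 confidences of 0.5 each, the sketch returns [0], a single box, which matches the two-step result rather than the eight boxes returned by the pipeline.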

System environment:

  • coremltools version: 8.2 (also happens on 7.2)
  • OS: macOS 15.1
  • Other relevant version information: TensorFlow 2.12.0 (also happens on 2.13.0)

Summary:

  • If we run the full model as a pipeline, pipeline_model.predict({'image': img}), we get incorrect results.
  • However, if we split the pipeline into two calls, nms_from_pipeline.predict(detector_from_pipeline.predict({'image': img})), the results are correct.
  • We also noticed that for smaller values of n_classes, e.g. 1 or 2, the code works as expected.
  • Question: are we using NMS correctly? We weren't able to find any explicit documentation, so we based this on how it is implemented in other repos (https://github.com/cloud-annotations/cloud-annotations/blob/main/training/scripts/trainer/src/convert/build_nms.py, https://github.com/hietalajulius/yolov5/blob/1023da95a54466cc320d79cc0408ea8b171d0321/export-nms.py#L237)

mateuszlugowski, Feb 10 '25 16:02

The maximum TensorFlow version that we support is 2.12.0. Also, coremltools 7.2 is about nine months old. Is this still an issue with the latest version of coremltools and a version of TensorFlow that we support?

TobyRoseman, Feb 11 '25 18:02

I just retested it on TensorFlow 2.12.0 and coremltools 8.2. The same issue is present.

mateuszlugowski, Feb 12 '25 14:02

Hey @TobyRoseman, any update on this?

kmkolasinski, Feb 17 '25 06:02

No major update. I was able to reproduce the issue (hence the "triaged" tag) but haven't been able to dedicate much time to it.

I don't see anything clearly wrong with the code, although it accesses a private member variable (i.e. ._spec). It would be better to call .get_spec() and use that.

The next step should probably be to edit the pipeline model so that it also outputs all of the first model's outputs. That should help narrow down the problem. Feel free to work on that.

TobyRoseman, Feb 18 '25 00:02

Hey, thanks for the reply. We tried to solve this issue on our side but without success :( Here are a few things we noticed that may help you:

  1. Adding an identity layer between the model and NMS solved the problem in Python and in the Xcode preview tool, but our Swift mobile app, which uses VNRequestHandler, was crashing. Here is the snippet we tried:
    import coremltools as ct
    from coremltools.converters.mil import Builder as mb

    # Shapes of the detector outputs (values from the repro above)
    num_boxes = 4032
    num_classes = 34

    # Names Identity and Identity_1 come from the converted Core ML detector
    # and refer to the confidence and coordinates outputs, respectively.
    @mb.program(
        input_specs=[
            mb.TensorSpec(shape=(num_boxes, num_classes)),
            mb.TensorSpec(shape=(num_boxes, 4)),
        ]
    )
    def prog(Identity, Identity_1):
        """Simple identity model that passes both detector outputs through unchanged."""
        rawConfidences = mb.identity(x=Identity, name="rawConfidences")
        rawCoordinates = mb.identity(x=Identity_1, name="rawCoordinates")

        return rawConfidences, rawCoordinates

    identity_model = ct.convert(prog)

    # Something like this solved the problem in Python and in the Xcode preview tool,
    # but crashes in the Swift mobile app
    ...
    pipeline.add_model(model)
    pipeline.add_model(identity_model)
    pipeline.add_model(nms_model)
    ...
  2. From what I remember, we had some issues with the public get_spec() not working as expected; maybe this is connected. Also, all public repositories that use NMS from coremltools take this approach (we just copied others). The question is: is there any official documentation on how to use it?

  3. The issue appears with larger values of n_classes and is consistent, which suggests bad shape information or memory issues.
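Because the NMS input in the repro is uniform (every confidence is 0.5), any zero in the pipeline's confidence output cannot have come from the detector. A rough diagnostic, assuming the NMS layer returns each kept box's confidence row unchanged (which is what the two-step run above produces), is to flag output rows that match no input row. rows_not_in_input is a hypothetical helper, not a coremltools API.

```python
import numpy as np


def rows_not_in_input(nms_input, nms_output, atol=1e-6):
    """Return indices of NMS output rows that match no row of the NMS input.

    Assumes the NMS layer only selects rows and never modifies them.
    """
    nms_input = np.asarray(nms_input)
    nms_output = np.asarray(nms_output)
    if nms_output.size == 0:
        return []
    if nms_output.shape[1] != nms_input.shape[1]:
        raise ValueError(
            f"column mismatch: {nms_output.shape[1]} vs {nms_input.shape[1]}"
        )
    bad = []
    for r, row in enumerate(nms_output):
        # Compare this row against every input row; a match needs all columns to agree.
        matches = np.all(np.isclose(nms_input, row, atol=atol), axis=1)
        if not matches.any():
            bad.append(r)
    return bad
```

With the repro objects, the input would be the detector's confidence output (its name is available via detector_from_pipeline.get_spec().description.output[0].name) and the output would be pipeline_model.predict({'image': img})['confidence']. Every row of the pipeline output shown in the issue contains zeros, so all eight rows would be flagged.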

kmkolasinski, Feb 18 '25 06:02