ValueError: C_in / groups = 1/1 != weight[1] (3) 🐞

Open ptmdmusique opened this issue 2 years ago • 11 comments

🐞Describe the bug

  • I'm trying to convert a detectron2 PyTorch model to Core ML following this tutorial, but got ValueError: C_in / groups = 1/1 != weight[1] (3). I successfully traced the detectron2 model with PyTorch, but was not able to convert the result to Core ML.
  • I'm using Unified Conversion API 5.1.0

Trace

WARNING:root:Tuple detected at graph output. This will be flattened in the converted model.
Converting Frontend ==> MIL Ops:   0%|          | 0/1817 [00:00<?, ? ops/s]WARNING:root:Saving value type of int64 into a builtin type of int32, might lose precision!
Converting Frontend ==> MIL Ops:   4%|▍         | 75/1817 [00:00<00:00, 1957.08 ops/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_1488/3792607256.py in <module>
     17 mlmodel = ct.convert(
     18   traced_script_module,
---> 19   inputs=[ct.TensorType(shape=(1, 3, 64, 64))],
     20 )

/opt/conda/lib/python3.7/site-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, useCPUOnly, package_dir, debug)
    335         compute_units=compute_units,
    336         package_dir=package_dir,
--> 337         debug=debug,
    338     )
    339 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    180         See `coremltools.converters.convert`
    181     """
--> 182     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
    183 
    184 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    212                             convert_to,
    213                             registry,
--> 214                             **kwargs
    215                          )
    216 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    298     frontend_converter = frontend_converter_type()
    299 
--> 300     prog = frontend_converter(model, **kwargs)
    301 
    302     if convert_to.lower() != "neuralnetwork":

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
    102         from .frontend.torch import load
    103 
--> 104         return load(*args, **kwargs)
    105 
    106 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     48     cut_at_symbols = kwargs.get("cut_at_symbols", None)
     49     converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols)
---> 50     return _perform_torch_convert(converter, debug)
     51 
     52 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in _perform_torch_convert(converter, debug)
     85 def _perform_torch_convert(converter, debug):
     86     try:
---> 87         prog = converter.convert()
     88     except RuntimeError as e:
     89         if debug and "convert function" in str(e):

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    237 
    238             # Add the rest of the operations
--> 239             convert_nodes(self.context, self.graph)
    240 
    241             graph_outputs = [self.context[name] for name in self.graph.outputs]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     74                 "PyTorch convert function for op '{}' not implemented.".format(node.kind)
     75             )
---> 76         add_op(context, node)
     77 
     78         # We've generated all the outputs the graph needs, terminate conversion.

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in _convolution(context, node)
    669     else:
    670         # Normal convolution
--> 671         conv = mb.conv(**kwargs)
    672     context.add(conv)
    673 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
     61             @classmethod
     62             def add_op(cls, **kwargs):
---> 63                 return cls._add_op(op_cls, **kwargs)
     64 
     65             setattr(Builder, op_type, add_op)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
    189         curr_block()._insert_op_before(new_op, before_op=before_op)
    190         new_op.build_nested_blocks()
--> 191         new_op.type_value_inference()
    192         if len(new_op.outputs) == 1:
    193             return new_op.outputs[0]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/operation.py in type_value_inference(self, overwrite_output)
    238         existing _output_vars
    239         """
--> 240         output_types = self.type_inference()
    241         if not isinstance(output_types, tuple):
    242             output_types = (output_types,)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/defs/conv.py in type_inference(self)
    168         if C_in // groups != self.weight.shape[1]:
    169             msg = "C_in / groups = {}/{} != weight[1] ({})"
--> 170             raise ValueError(msg.format(C_in, groups, self.weight.shape[1]))
    171 
    172         strides = self.strides.val

ValueError: C_in / groups = 1/1 != weight[1] (3)

To Reproduce

# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

from detectron2.modeling import build_model

TRAIN_SET = "my_train"
TEST_SET = "my_test"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(model_type))
cfg.DATASETS.TRAIN = (TRAIN_SET,)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_type)  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 600
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class

cfg.OUTPUT_DIR = "./output"

import coremltools as ct
import torch
import torchvision
from detectron2.export.flatten import TracingAdapter

# https://stackoverflow.com/questions/66303312/how-to-convert-custom-pytorch-model-to-torchscript-pth-to-pt-model
def inference_func(model, image):
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]

model = build_model(cfg)
example = torch.rand(3, 224, 224)
wrapper = TracingAdapter(model, example, inference_func)
wrapper.eval()
traced_script_module = torch.jit.trace(wrapper, (example,))

mlmodel = ct.convert(
  traced_script_module,
  inputs=[ct.TensorType(shape=(1, 3, 64, 64))],
)

System environment (please complete the following information):

  • coremltools version (e.g., 3.0b5): 5.1
  • OS (e.g., MacOS, Linux): GCP Deep Learning Colab
  • How you install python (anaconda, virtualenv, system): Colab's pip install

Additional context

  • Detectron2: link
  • Detectron2's TracingAdapter: link

Nov 12 '21 04:11 ptmdmusique

This seems at least somewhat related to #1322. Both are trying to convert some type of Detectron2 model. However, the errors are different.

@ptmdmusique - What is cfg? That is not defined in your code. I'm not going to be able to help you if I can't reproduce the problem.

Nov 12 '21 22:11 TobyRoseman

Thanks for the fast response. I just updated the post with the import and config cells.

Do you have a portal or email address where I can send you the model to test?

Nov 12 '21 23:11 ptmdmusique

Is there any update on detectron2 to Core ML conversion? I am having a similar issue:

RuntimeError: 
object has no attribute nms:
  File "/home/ubuntu/anaconda3/envs/ai_env/lib/python3.9/site-packages/torchvision/ops/boxes.py", line 42
    """
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'nms' is being compiled since it was called from 'batched_nms'
  File "/home/ubuntu/anaconda3/envs/ai_env/lib/python3.9/site-packages/torchvision/ops/boxes.py", line 88
        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
        boxes_for_nms = boxes + offsets[:, None]
        keep = nms(boxes_for_nms, scores, iou_threshold)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return keep

Here is my code sample to reproduce:

import coremltools as ct
import torch
from detectron2.config import get_cfg
from detectron2.export.flatten import TracingAdapter
from detectron2.modeling import build_model

def inference(model, image):
    inputs = [{"image": image}]
    output = model.inference(inputs, do_postprocess=False)[0]
    return output

modelFile = "..." # torch model file, generated by detectron2 framework
inputShape = (3, 512, 512)
inputTensor = torch.rand(inputShape)

config = get_cfg()
config.MODEL.WEIGHTS = modelFile
config.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 
config.MODEL.DEVICE = "cuda"

model = build_model(config)
traceableModel = TracingAdapter(model, inputTensor, inference).eval()
trace = torch.jit.trace(traceableModel, (inputTensor,))

mlmodel = ct.convert(trace, inputs=[ct.ImageType(shape=inputTensor.shape, scale=1.0/255.0, bias=[0,0,0])]) # Error in this line

I can also share my ".pth" file.

Library Versions:

  • torch==1.9.0
  • coremltools==5.1.0
  • detectron2==0.6+cu102

Dec 02 '21 11:12 DenizD

@TobyRoseman is there any update on the issue?

Jan 06 '22 05:01 ptmdmusique

I have the same problem. It seems like the convolution shape is wrong?

Because coremltools/converters/mil/mil/ops/defs/conv.py raises an input type error, doesn't it?

Jan 16 '22 03:01 john-rocky

Here's a minimal example which I hope helps in pinpointing the error. It's a single-layer model with only one standard PyTorch convolution. My hunch is that the input channels are incorrectly inferred by coremltools since they are dynamically calculated within the model's init, but I'm not sure yet.

import coremltools as ct
import torch
from torch import nn


class NetMin(nn.Module):
    """Minimal architecture to reproduce CoreML export error"""

    def __init__(self, crop_height=288, crop_width=512, out_channels=64, grayscale=False, num_frames=3):
        super().__init__()

        # Transfer parameters to attributes
        self.crop_height = crop_height
        self.crop_width = crop_width

        # Determine number of input channels to the encoder
        channels_per_frame = 1 if grayscale else 3
        self.in_channels = channels_per_frame * num_frames
        self.encoder = nn.Conv2d(self.in_channels, out_channels, kernel_size=3)

    def forward(self, imgs):
        return self.encoder(imgs)

    def to_coreml(self, path='tracknet.mlmodel'):
        """Reproduces ValueError: C_in / groups = 290/1 != weight[1] (3)"""

        # Generate sequence of 3 frames (either grayscale or RGB depending on init params)
        example_input_array = torch.randn(1, self.in_channels, self.crop_height, self.crop_width)

        # Run it through tracing + conversion
        self.eval()
        traced_model = torch.jit.trace(self, example_input_array)
        inputs = [ct.TensorType(shape=inp.shape) for inp in example_input_array]
        mlmodel = ct.convert(traced_model, inputs=inputs)
        mlmodel.save(path)


model = NetMin(grayscale=True)  # Same issue for grayscale=False
model.to_coreml()

Versions tested

  • coremltools 5.1.0
  • torch 1.9.1 and 1.10.0

Jan 27 '22 17:01 addisonklinke

Figured out my issue (illustrated below with a conv layer). Unfortunately, it's not the same as the Detectron issue, because the code by @ptmdmusique already uses the correct inputs_single approach.

import coremltools as ct
import torch
from torch import nn


model = nn.Conv2d(3, 64, 3)
example_input_array = torch.randn(1, 3, 288, 512)
model.eval()
traced_model = torch.jit.trace(model, example_input_array)

# Wrong - iteration removes the first (i.e. batch) axis of ``example_input_array``
# This means coremltools will see shape [C, H, W] and assume C_in = shape[1] (which is H, not C)
inputs_array = [ct.TensorType(shape=inp.shape) for inp in example_input_array]
mlmodel = ct.convert(traced_model, inputs=inputs_array)

# Correct (option 1) - preserve full shape and wrap in list
inputs_single = [ct.TensorType(shape=example_input_array.shape)]
mlmodel = ct.convert(traced_model, inputs=inputs_single)

# Correct (option 2) - wrap ``example_input_array`` in another iterable (i.e. tuple or list)
inputs_wrapped = [ct.TensorType(shape=inp.shape) for inp in (example_input_array, )]
mlmodel = ct.convert(traced_model, inputs=inputs_wrapped)

Jan 27 '22 18:01 addisonklinke

@addisonklinke - Thanks for your work here. I can reproduce the error message using the code from your first post. That is very helpful.

I'm not sure I understand your second post. Since the PyTorch model has already been traced, I don't think it makes sense to convert the model using an input which is different than the input used for tracing.
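One way to catch this kind of mismatch before calling ct.convert is to compare the shape used for tracing with the shape declared for conversion. The following is a minimal sketch in plain Python with shapes as tuples. The describe_mismatch helper is hypothetical, not part of coremltools, and it assumes fixed-size shapes, so it ignores flexible input shapes entirely.

```python
def describe_mismatch(traced_shape, declared_shape):
    """Return a description of the mismatch, or None if the shapes agree.

    Hypothetical helper, not a coremltools API; assumes fixed-size shapes.
    """
    traced, declared = tuple(traced_shape), tuple(declared_shape)
    if len(traced) != len(declared):
        return f"rank mismatch: traced {traced} vs declared {declared}"
    if traced != declared:
        return f"size mismatch: traced {traced} vs declared {declared}"
    return None


# Shapes taken from the original report in this thread:
# traced with torch.rand(3, 224, 224), declared as (1, 3, 64, 64).
report_problem = describe_mismatch((3, 224, 224), (1, 3, 64, 64))

# Shapes from the single-convolution example, declared with the full shape.
conv_problem = describe_mismatch((1, 3, 288, 512), (1, 3, 288, 512))
```

Whether a rank mismatch explains the Detectron2 error is not established in this thread; the sketch only flags that the two shapes differ.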

I encourage you to keep looking into this issue. I would welcome any pull requests that fix this issue.

Jan 27 '22 19:01 TobyRoseman

@TobyRoseman regarding my second post - the reason I got an error from ct.convert was that I was inadvertently using a different model input than what had been used for tracing. Since PyTorch tensors are iterables, the wrong approach I demonstrated has the side effect of removing the first dimension (i.e. shape [1, 3, 288, 512] becomes [3, 288, 512]). In comparison, to preserve the model input shape you can either

  • Avoid iteration by wrapping directly in a list (option 1)
  • Wrap in another iterable prior to iteration (option 2)
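The same effect can be seen without PyTorch, since iterating over any batched array-like walks its first axis. Here is a small sketch that uses nested lists as a stand-in for a tensor. The zeros and shape_of helpers are illustrative only and do not come from any library.

```python
def zeros(*dims):
    """Build a rectangular nested list of zeros with the given dimensions."""
    if len(dims) == 1:
        return [0.0] * dims[0]
    return [zeros(*dims[1:]) for _ in range(dims[0])]


def shape_of(nested):
    """Return the shape of a rectangular, non-empty nested list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


batch = zeros(1, 3, 4, 5)  # stand-in for a [N, C, H, W] tensor with N = 1

# Wrong: iterating over the batch itself walks axis 0, so each item has lost it.
per_item_shapes = [shape_of(item) for item in batch]

# Option 1: take the shape of the whole batch and wrap it in a list.
single_shapes = [shape_of(batch)]

# Option 2: wrap the batch in a tuple first, so iteration yields the batch itself.
wrapped_shapes = [shape_of(item) for item in (batch,)]
```

Only the first variant ends up with a rank-3 shape; both options keep all four axes.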

Does that help clarify?

Jan 27 '22 20:01 addisonklinke

@TobyRoseman 👋 hello, is there any update on the issue?

Apr 06 '22 04:04 ptmdmusique

@TobyRoseman any update on this by any chance? I would be thankful for your support. Please let me know if you need code or models to reproduce.

Jul 10 '22 14:07 luliuzee

Thanks @addisonklinke for the clarification. Yeah, the coremltools behavior here is correct.

That said, a minor change in coremltools would make issues like this easier to find. I'll put up a pull request to display shape information when the error is printed.
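As a rough illustration of what shape information in the message could look like, the channel check shown in the trace above can be written against plain shape tuples, with the full shapes appended to the existing message. This is only a sketch: check_conv_channels is a hypothetical function, not the actual coremltools change, and it assumes the input channels sit on axis 1.

```python
def check_conv_channels(x_shape, weight_shape, groups=1):
    """Raise ValueError if the input channels do not match the weight.

    Hypothetical stand-in for the check in conv.py; assumes channels are on
    axis 1 of the input and that weight_shape[1] is C_in / groups.
    """
    c_in = x_shape[1]
    if c_in // groups != weight_shape[1]:
        raise ValueError(
            f"C_in / groups = {c_in}/{groups} != weight[1] ({weight_shape[1]}); "
            f"input shape = {tuple(x_shape)}, weight shape = {tuple(weight_shape)}"
        )


# Batch axis missing: axis 1 is now the height (64), not the channel count.
try:
    check_conv_channels((3, 64, 64), (16, 3, 7, 7))
    message = None
except ValueError as err:
    message = str(err)

# With the batch axis present, the same weight passes the check.
check_conv_channels((1, 3, 64, 64), (16, 3, 7, 7))
```

With both shapes in the message, a rank-3 input stands out immediately, which the original "C_in / groups" text alone does not show.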

Jul 31 '23 20:07 TobyRoseman