
Complex op such as `irfftn` doesn't support dynamic shapes

Open chophilip21 opened this issue 1 year ago • 6 comments

🐞Describing the bug

  • I'm trying to convert this PyTorch LaMa model to Core ML, and it works fine when the model uses a fixed input size (see the fixed-shape sketch below for contrast). However, when I use flexible shapes, I get a shape mismatch error between real_data and imag_data. It seems the converter is not able to process the symbolic shapes correctly.
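
For contrast, a fixed-shape conversion along these lines goes through without errors (a minimal sketch only; the sizes and the `jit_model` name are illustrative, matching the snippets further below):

    import coremltools as ct

    # Minimal fixed-shape sketch (illustrative sizes): with a single concrete
    # shape per input, no symbolic dimensions reach the complex ops.
    coreml_model = ct.convert(
        jit_model,  # the traced/scripted LaMa model
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
        inputs=[
            ct.ImageType(name="image", shape=(1, 3, 512, 512), scale=1 / 255.0),
            ct.ImageType(name="mask", shape=(1, 1, 512, 512),
                         color_layout=ct.colorlayout.GRAYSCALE),
        ],
        outputs=[ct.ImageType(name="output")],
        skip_model_load=True,
    )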

Stack Trace

DEBUG:coremltools:Adding const op '256_end_0'
INFO:coremltools:Adding op '256_end_0' of type const
DEBUG:coremltools:Downcast const op 256_end_0 dataint64 as int32
DEBUG:coremltools:Downcast const op 256_end_0 dataint64 as int32
DEBUG:coremltools:Adding const op '256_end_mask_0'
INFO:coremltools:Adding op '256_end_mask_0' of type const
DEBUG:coremltools:Adding const op '256_squeeze_mask_0'
INFO:coremltools:Adding op '256_squeeze_mask_0' of type const
INFO:coremltools:Converting op ffted3.3 : complex
INFO:coremltools:Adding op 'complex_0' of type complex
Converting PyTorch Frontend ==> MIL Ops:   5%|███████▏                              | 152/3297 [00:00<00:04, 650.29 ops/s]
Traceback (most recent call last):
  File "/Users/philip/SkyCoreML/main.py", line 51, in <module>
    main()
  File "/Users/philip/SkyCoreML/main.py", line 43, in main
    lama.convert(args)
  File "/Users/philip/SkyCoreML/src/skycoreml/conversion/lama.py", line 124, in convert
    coreml_model = ct.convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 530, in convert
    mlmodel = mil_convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 63, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 102, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 439, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 92, in convert_nodes
    add_op(context, node)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 5691, in complex
    result = mb.complex(real_data=real_part, imag_data=imag_part)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 182, in _add_op
    new_op.type_value_inference()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/operation.py", line 253, in type_value_inference
    output_types = self.type_inference()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py", line 162, in type_inference
    raise ValueError(
ValueError: The shape of real_data ((1, 192, is32, is33)) and imag_data ((1, 192, is34, is35)) must match to construct complex data.

To Reproduce

To reproduce, clone the CoreLama repository above and run convert_lama.py. The only change needed is to give the image and mask inputs some kind of flexible shape, for example:

    image_shape = ct.EnumeratedShapes(
        shapes=[[1, 3, 256, 256], [1, 3, 512, 512], [1, 3, 1024, 1024]],
        default=[1, 3, 512, 512],
    )

    mask_shape = ct.EnumeratedShapes(
        shapes=[[1, 1, 256, 256], [1, 1, 512, 512], [1, 1, 1024, 1024]],
        default=[1, 1, 512, 512],
    )

    coreml_model = ct.convert(
        jit_model,
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
        inputs=[
            ct.ImageType(name="image",
                         shape=image_shape,
                         scale=1 / 255.0),
            ct.ImageType(name="mask",
                         shape=mask_shape,
                         color_layout=ct.colorlayout.GRAYSCALE),
        ],
        outputs=[ct.ImageType(name="output")],
        skip_model_load=True,
    )

Things I have tried

It doesn't matter whether the model is traced or scripted; the conversion fails as long as flexible shapes are enabled, even though it should produce an mlpackage properly. I have tried multiple versions of PyTorch and coremltools, with no luck. I also tried going back to the legacy neuralnetwork option, but the code below creates a model file that seems to have some kind of memory leak.

        # assumes: import os, coremltools as ct, and
        # from coremltools.models.neural_network import flexible_shape_utils

        # see if you can update the spec
        spec = ct.utils.load_spec(coreml_model_file_name)
        image = spec.description.input[0].name
        mask = spec.description.input[1].name

        # update the image
        flexible_shape_utils.set_multiarray_ndshape_range(
            spec,
            feature_name=image,
            lower_bounds=[1, 3, 256, 256],
            upper_bounds=[1, 3, 1024, 1024],
        )

        # update the mask
        flexible_shape_utils.set_multiarray_ndshape_range(
            spec,
            feature_name=mask,
            lower_bounds=[1, 1, 256, 256],
            upper_bounds=[1, 1, 1024, 1024],
        )

        # save the spec
        coreml_model_file_name = "LaMa_updated.mlmodel"
        coreml_model_file_name = os.path.join(save_dir, coreml_model_file_name)
        ct.utils.save_spec(spec, coreml_model_file_name)

System environment (please complete the following information):

  • coremltools version: 7.0b1
  • OS (e.g. MacOS version or Linux type): Mac Mini with M2 chip, macOS Sonoma
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.0.1
  • Tried both Python 3.9 and 3.11.

Additional context

  • I know that support for complex numbers is quite limited at the moment, but being able to use flexible shapes is very important. Is there any way to work around this?

chophilip21 commented Aug 25 '23 17:08

Hi @chophilip21, thank you for reporting this issue with detailed info!

You are right: when checking shapes in the complex op, it should also account for dynamic/symbolic shapes.

A quick fix on your side would be to change the type_inference method at line 162 of "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py":

from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic

...

def type_inference(self):
    # Don't compare the shape directly if there is symbolic shape.
    if any_symbolic(self.real_data.shape) or any_symbolic(self.imag_data.shape):
        # Checking the non-symbolic dim to make sure they match.
        for dim, dim_size in enumerate(self.real_data.shape):
            if not is_symbolic(dim_size):
                assert dim_size == self.imag_data.shape[dim]
    else:
        # Here is the original shape checking logic.
        if self.real_data.shape != self.imag_data.shape:
            ...

Note that this is just a draft, but it should explain the logic. Thanks!

junpeiz commented Aug 25 '23 21:08

@junpeiz

Thank you very much for your help!

    def type_inference(self):
        # Don't compare the shape directly if there is symbolic shape.
        if any_symbolic(self.real_data.shape) or any_symbolic(self.imag_data.shape):
            # Checking the non-symbolic dim to make sure they match.
            for dim, dim_size in enumerate(self.real_data.shape):
                if not is_symbolic(dim_size):
                    assert dim_size == self.imag_data.shape[dim]
        else:
            if self.real_data.shape != self.imag_data.shape:
                raise ValueError(
                    f"The shape of real_data ({self.real_data.shape}) and imag_data "
                    f"({self.imag_data.shape}) must match to construct complex data."
                )
        return types.tensor(
            infer_complex_dtype(self.real_data.dtype, self.imag_data.dtype),
            self.real_data.shape,
        )

I can bypass the check there, but unfortunately the conversion does not get much further past that point. You then hit another complex-data issue:

DEBUG:coremltools:Adding const op 'gather_6_axis_0'
INFO:coremltools:Adding op 'gather_6_axis_0' of type const
INFO:coremltools:Converting op 259 : listconstruct
INFO:coremltools:Converting op 260 : listconstruct
INFO:coremltools:Adding op '260' of type const
INFO:coremltools:Converting op output.3 : fft_irfftn
INFO:coremltools:Adding op 'complex_irfftn_0' of type complex_irfftn
Converting PyTorch Frontend ==> MIL Ops:   5%|█████▊                              | 156/3294 [00:00<00:01, 1915.99 ops/s]
Traceback (most recent call last):
  File "/Users/philip/SkyCoreML/main.py", line 40, in <module>
    main()
  File "/Users/philip/SkyCoreML/main.py", line 35, in main
    lama.convert(args)
  File "/Users/philip/SkyCoreML/src/skycoreml/conversion/lama.py", line 136, in convert
    coreml_model = ct.convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 551, in convert
    mlmodel = mil_convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 481, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 93, in convert_nodes
    add_op(context, node)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 5893, in fft_irfftn
    irfftn_res = mb.complex_irfftn(data=input_data, shapes=shapes, dims=dims, norm=norm)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 164, in _add_op
    kwargs.update(cls._create_vars(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 147, in _create_vars
    var = cls._add_const(val, new_var_name, before_op)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 76, in _add_const
    raise ValueError("Cannot add const {}".format(val))
ValueError: Cannot add const [<coremltools.converters.mil.mil.var.Var object at 0x16b983700>, <coremltools.converters.mil.mil.var.Var object at 0x16b9838e0>]

The traceback above points at:

@register_torch_op
def fft_irfftn(context, node):
    """Lowers torch.fft.irfftn by the dialect op `complex_irfftn` from complex_dialect_ops.py."""
    input_data, shapes, dims, norm = _get_inputs(context, node, expected=[4])
    irfftn_res = mb.complex_irfftn(data=input_data, shapes=shapes, dims=dims, norm=norm)
    context.add(irfftn_res, node.name)

This may be because we are just trying to bypass the checks above. Or could there be another reason for this?

chophilip21 commented Aug 25 '23 22:08

@chophilip21 You are right. The issue is that the input shapes passed to mb.complex_irfftn are dynamic, while in the op definition in coremltools/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py the shapes input needs to be a const, which is why the error says "Cannot add const".
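
To make the distinction concrete, here is a small PyTorch-only illustration (a paraphrase of the situation, not code taken from LaMa or coremltools): when `s` is a plain tuple of ints it can be folded into the const `shapes` input, but when it comes from another tensor's runtime shape it becomes symbolic once the model inputs are flexible.

    import torch

    x = torch.randn(1, 192, 64, 33, dtype=torch.complex64)

    # `s` given as Python ints: the converter can lower these into the const
    # `shapes` input of complex_irfftn.
    out_const = torch.fft.irfftn(x, s=(64, 64), dim=(-2, -1), norm="ortho")

    # `s` taken from another tensor's runtime shape: with flexible model inputs
    # these sizes are symbolic during conversion, so the builder cannot fold
    # them into a const, which is where "Cannot add const" comes from.
    y = torch.randn(1, 192, 64, 64)
    out_dyn = torch.fft.irfftn(x, s=y.shape[-2:], dim=(-2, -1), norm="ortho")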

Here is a minimal example to reproduce the issue (so you can also try it on your end without debugging the full LaMa model), which could be placed in coremltools/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:

    @pytest.mark.parametrize(
        "compute_unit, backend",
        itertools.product(compute_units, backends),
    )
    def test_fftn_dynamic_shape(
        self, compute_unit: ct.ComputeUnit, backend
    ):
        class FftnModel(torch.nn.Module):
            def forward(self, x, y):
                x = torch.complex(x, x)
                return torch.fft.irfftn(x, s=y.shape, dim=None, norm=None)

        input_data = [torch.rand(2, 3, 4), torch.rand(1, 4)]
        input_type = [
            ct.TensorType(shape=(2, 3, RangeDim(1, 10))),
            ct.TensorType(shape=(RangeDim(1, 10), RangeDim(1, 10))),
        ]
        TorchBaseTest.run_compare_torch(
            input_data, FftnModel(), backend=backend, compute_unit=compute_unit, input_as_shape=False, converter_input_type=input_type,
        )
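
(Assuming the test is added to the FFT-related test class in that file, a standard pytest run such as `pytest coremltools/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py -k test_fftn_dynamic_shape` should reproduce the "Cannot add const" failure.)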

We will use this thread as a feature request for "Supporting dynamic shapes in complex irfftn op". Thank you for reporting this issue!

junpeiz commented Aug 28 '23 05:08

Thank you, I look forward to hearing updates on this!

chophilip21 commented Aug 28 '23 16:08

Hello everyone, what is the status of this feature request? Does Core ML support dynamic shapes in these complex ops already?

StevenSK-king commented Oct 24 '23 03:10

Any update on this issue?

StevenSK-king commented Dec 13 '23 15:12