
repeat with flexible shape (EnumeratedShapes)


🐞 The repeat operation does not work with dynamic inputs

  • See the repro below: repeat only seems to work when the reps values are constant
  • The same problem happens with expand_as, which uses repeat underneath (see the expand_as sketch after the repro)

Trace

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-933a2050e432> in <module>()
     24 ct.convert(script_model, inputs=[
     25     ct.TensorType(shape=ct.EnumeratedShapes(shapes=[(1, 1), (2, 1)])),
---> 26     ct.TensorType(shape=(1,))
     27 ])

/usr/local/lib/python3.7/dist-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, **kwargs)
    180         outputs=outputs,
    181         classifier_config=classifier_config,
--> 182         **kwargs
    183     )
    184 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, **kwargs)
    127     """
    128     proto = mil_convert_to_proto(model, convert_from, convert_to,
--> 129         ConverterRegistry, **kwargs)
    130     if convert_to == 'mil':
    131         return proto

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    169     frontend_converter = frontend_converter_type()
    170 
--> 171     prog = frontend_converter(model, **kwargs)
    172     common_pass(prog)
    173 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
     83         from .frontend.torch import load
     84 
---> 85         return load(*args, **kwargs)
     86 
     87 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     81         raise e
     82     except Exception as e:
---> 83         raise e
     84 
     85     return prog

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     71 
     72     try:
---> 73         prog = converter.convert()
     74     except RuntimeError as e:
     75         if debug and "convert function" in str(e):

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    225 
    226             # Add the rest of the operations
--> 227             convert_nodes(self.context, self.graph)
    228 
    229             graph_outputs = [self.context[name] for name in self.graph.outputs]

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     56             )
     57         else:
---> 58             _add_op(context, node)
     59 
     60         # We've generated all the outputs the graph needs, terminate conversion.

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/ops.py in repeat(context, node)
   2434     x = context[node.inputs[0]]
   2435     reps = context[node.inputs[1]]
-> 2436     context.add(mb.tile(x=x, reps=reps, name=node.name))
   2437 
   2438 @register_torch_op

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
     60             @classmethod
     61             def add_op(cls, **kwargs):
---> 62                 return cls._add_op(op_cls, **kwargs)
     63 
     64             setattr(Builder, op_type, add_op)

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
    171             input_spec=op_cls.input_spec,
    172             op_name=kwargs["name"], before_op=before_op,
--> 173             candidate_kv=kwargs))
    174         new_op = op_cls(**kwargs)
    175 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/mil/builder.py in _create_vars(cls, input_spec, op_name, before_op, candidate_kv)
    152             if isinstance(in_type, (ScalarOrTensorInputType,
    153               ListOrScalarOrTensorInputType)):
--> 154                 var = cls._add_const(val, new_var_name, before_op)
    155                 update_dict[k] = var
    156 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/mil/builder.py in _add_const(cls, val, name, before_op)
     81     def _add_const(cls, val, name, before_op):
     82         if not is_python_value(val):
---> 83             raise ValueError("Cannot add const {}".format(val))
     84         if any_symbolic(val):
     85             msg = (

ValueError: Cannot add const [<coremltools.converters.mil.mil.var.Var object at 0x7f74bae8d1a0>]

To Reproduce

Simple repro on Google Colab:

!pip install coremltools
import torch
from torch import nn
import coremltools as ct

class Demo(nn.Module):
  def forward(self, x, y):
    bs = x.size(0)
    return y.repeat(bs)

inputs = [torch.tensor([1, 1]), torch.tensor([2])]
module = Demo()
module(*inputs) # outputs `tensor([2, 2])`

script_model = torch.jit.trace(module, inputs)

# Fixed input shape - works
ct.convert(script_model, inputs=[
    ct.TensorType(shape=(1, 1)),
    ct.TensorType(shape=(1,))
])

# Enumerated shapes - fails with error "ValueError: Cannot add const [<coremltools.converters.mil.mil.var.Var object at 0x7f74bae8d1a0>]"
ct.convert(script_model, inputs=[
    ct.TensorType(shape=ct.EnumeratedShapes(shapes=[(1, 1), (2, 1)])),
    ct.TensorType(shape=(1,))
])
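
For reference, a sketch of the expand_as variant mentioned at the top. This is my addition rather than code from the original report; it continues the script above and, per the report, runs into the same "Cannot add const" error:

class DemoExpand(nn.Module):
  def forward(self, x, y):
    # expand_as takes its target shape from x, so the shape is dynamic
    return y.expand_as(x)

expand_inputs = [torch.tensor([[1], [1]]), torch.tensor([2])]
expand_model = torch.jit.trace(DemoExpand(), expand_inputs)
expand_model(*expand_inputs) # outputs `tensor([[2], [2]])`

# Per the report, this also fails once a flexible shape is used for x
ct.convert(expand_model, inputs=[
    ct.TensorType(shape=ct.EnumeratedShapes(shapes=[(1, 1), (2, 1)])),
    ct.TensorType(shape=(1,))
])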

System environment:

  • coremltools version: 4.1
  • OS: Linux
  • How you installed Python: Google Colab

danthe3rd (Apr 14 '21)

Thanks for the bug report. I can reproduce this issue using your code.

TobyRoseman (Apr 16 '21)

@TobyRoseman Do you plan to fix this bug in the next release? Otherwise, could you point me to where I should look to fix it myself?

GrimReaperSam (May 14 '21)

> @TobyRoseman Do you plan to fix this bug in the next release? Otherwise, could you point me to where I should look to fix it myself?

I don't have any immediate plans to work on this. I really don't know how to solve this issue, so I'm not going to be able to give much guidance. I wish I could be more help.

@jakesabathia2 or @aseemw - can you offer any guidance here?

TobyRoseman (May 14 '21)

Any updates on this?

askaradeniz (Nov 01 '21)

As a workaround, you can use torch.cat to concatenate the same tensor n times. You can also express repeat as a matrix multiplication (with a diagonal of ones), as sketched below.
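
For illustration, a rough sketch of both ideas applied to the Demo module above. This is my own interpretation (the matmul version uses a column of ones rather than the diagonal mentioned in the suggestion) and has not been verified against flexible-shape conversion:

import torch

def repeat_workaround(x, y):
    # Option 1: concatenate y with itself x.size(0) times. Note that under
    # torch.jit.trace the Python list repetition is unrolled at trace time,
    # so the count is baked into the graph.
    out_cat = torch.cat([y] * x.size(0))

    # Option 2: express repeat as a matrix multiplication. A (bs, 1) column
    # of ones times the (1, m) row y gives a (bs, m) matrix whose rows are
    # copies of y; flattening it reproduces y.repeat(bs).
    ones = torch.ones_like(x, dtype=torch.float32).unsqueeze(1)  # (bs, 1)
    out_mm = (ones @ y.to(torch.float32).unsqueeze(0)).flatten().to(y.dtype)

    return out_cat, out_mm

x, y = torch.tensor([1, 1]), torch.tensor([2])
print(repeat_workaround(x, y))  # both equal y.repeat(x.size(0)) == tensor([2, 2])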

spolezhaev (Feb 11 '22)