
Feature request to support the aten::IntImplicit op

Open · w4-jihunlorenzopark opened this issue 3 weeks ago · 3 comments

🌱 Describe your Feature Request

Hello, I would like to request a feature to convert aten::IntImplicit ops.

How can this feature be used?

This would resolve the error described in another issue: RuntimeError: PyTorch convert function for op 'intimplicit' not implemented.

From my digging, aten::IntImplicit appears when a zero-dimensional integer tensor (such as torch.tensor(1)) is used in a place that requires a native integer type.

In my case, my module has slicing code inside a torch.jit.script function:

@torch.jit.script
def some_function(iter_num: int, data_tensor, length_tensor):
    for i in range(iter_num):
        ...
        # length_tensor is a 1-d tensor, so length_tensor[i] is a zero-dim tensor, i.e. `torch.tensor(1)`
        item = data_tensor[..., :length_tensor[i]]
    ...

which raises RuntimeError: PyTorch convert function for op 'intimplicit' not implemented. when calling ct.convert. OpenVINO, which targets Intel CPUs, seems to have a translation for the IntImplicit op. I believe implementing this op would help coremltools users convert many open-source models in the wild, whose authors may not care about this error during coremltools conversion.

Describe alternatives you've considered

I am trying to figure out how I can run the same logic without producing the IntImplicit op (see the sketch below). Any workaround would be welcome!
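
One thing I am currently trying (not yet verified end to end with coremltools) is forcing an explicit conversion so that TorchScript should emit aten::Int instead of the implicit variant. The snippet below just mirrors my slicing example above; the explicit int() call and the length_i name are my own additions:

@torch.jit.script
def some_function(iter_num: int, data_tensor, length_tensor):
    for i in range(iter_num):
        # int(...) makes the conversion explicit, so TorchScript should emit aten::Int
        # rather than aten::IntImplicit; I have not yet confirmed that coremltools
        # converts aten::Int in this position either.
        length_i = int(length_tensor[i])
        item = data_tensor[..., :length_i]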

Additional context

Here is a simple reproducible example that produces the RuntimeError: PyTorch convert function for op 'intimplicit' not implemented. error.

import torch
import torch.nn as nn
import coremltools as ct

@torch.jit.script
def select_with_tensor(a):
    index = torch.tensor(1)
    # torch.select expects the index as a native integer type, not a torch.Tensor, so it seems
    # an IntImplicit op is added during torch.jit.script compilation.
    # If we use index = 1 instead, there is no IntImplicit op and no error during conversion.
    return a.select(1, index)

class SimpleModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)

    def forward(self, x):
        out = self.linear(x)
        sliced_tensor = select_with_tensor(out)
        return sliced_tensor

module = SimpleModule()
module.eval()

example_input = torch.rand(1, 5)
traced_module = torch.jit.trace(module, example_input)
model = ct.convert(
    traced_module,
    convert_to="mlprogram",
    inputs=[ct.TensorType(shape=example_input.shape)]
)

Here is the TorchScript graph of select_with_tensor:

print(select_with_tensor.graph)

graph(%a.1 : Tensor):
  %4 : bool = prim::Constant[value=0]()
  %2 : NoneType = prim::Constant()
  %1 : int = prim::Constant[value=1]() # /var/folders/c5/yknsh7rd1jz7schyq8qt4lp8000507/T/ipykernel_13384/701066052.py:7:25
  %index.1 : Tensor = aten::tensor(%1, %2, %2, %4) # /var/folders/c5/yknsh7rd1jz7schyq8qt4lp8000507/T/ipykernel_13384/701066052.py:7:12
  %8 : int = aten::IntImplicit(%index.1) # /var/folders/c5/yknsh7rd1jz7schyq8qt4lp8000507/T/ipykernel_13384/701066052.py:10:11
  %9 : Tensor = aten::select(%a.1, %1, %8) # /var/folders/c5/yknsh7rd1jz7schyq8qt4lp8000507/T/ipykernel_13384/701066052.py:10:11
  return (%9)

Thanks for reading this. I am new to coremltools and its codebase, so if someone could provide references or guidance on how to implement this, I would gladly make a PR for it.
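
In case it helps the discussion, here is a very rough sketch of what I imagine the registration could look like, modeled on the existing handlers in coremltools/converters/mil/frontend/torch/ops.py. I am assuming IntImplicit can be lowered as a simple cast of the 0-dim tensor to an integer scalar, and the use of _get_inputs and mb.cast here is my guess at the right helpers, so please correct me if the approach is wrong:

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def intimplicit(context, node):
    # aten::IntImplicit only reinterprets a 0-dim integer tensor as a plain int,
    # so (as far as I can tell) forwarding the value with an int32 cast may be enough.
    x = _get_inputs(context, node, expected=1)[0]
    context.add(mb.cast(x=x, dtype="int32", name=node.name))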

w4-jihunlorenzopark · Jun 21 '24 07:06