
Please add support for `torch.tensor_split`

Open · mallman opened this issue on May 21, 2024 · 3 comments

  • Name of layer type: torch.tensor_split
  • Is this a PyTorch or a TensorFlow layer type: PyTorch
  • Your version of coremltools: 7.2
  • Your version of PyTorch/TensorFlow: 2.3.0
  • Impact of supporting this layer type. Why is adding support for this layer type important? Is it necessary to support a popular model or use case?

This layer/op is used by EVA-02, a model for image classification, segmentation and object detection. Personally, I'm interested in using it for image classification in a Mac app.

As of this writing (May 21st, 2024), various sizes of pre-trained EVA and EVA-02 models dominate the leaderboard for image classification on ImageNet 1k among the models curated by the PyTorch Image Models Hugging Face org. See https://huggingface.co/collections/timm/timm-top-20-imagenet-1k-models-655d78909af37bae32381f61

FYI, it looks like this is (essentially) the same op as tf.split from TensorFlow.
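The "(essentially)" hides one practical difference. With an integer section count, `tf.split` requires the dimension to divide evenly, while `torch.tensor_split` does not: per the PyTorch docs, the first `size % n` sections get one extra element. The pure-Python sketch below computes the chunk sizes under each rule. The function names are hypothetical, and this is not code from either library.

```python
def tensor_split_sizes(length, sections):
    """Chunk sizes torch.tensor_split(x, sections) yields along a dim of `length`."""
    if sections < 1:
        raise ValueError("sections must be >= 1")
    base, extra = divmod(length, sections)
    # The first `extra` chunks get one extra element each; the rest get `base`.
    return [base + 1] * extra + [base] * (sections - extra)


def tf_even_split_sizes(length, num_splits):
    """Chunk sizes tf.split(x, num_splits) yields; it needs an even division."""
    if num_splits < 1 or length % num_splits != 0:
        raise ValueError("length must be divisible by num_splits")
    return [length // num_splits] * num_splits


print(tensor_split_sizes(8, 3))   # [3, 3, 2]
print(tensor_split_sizes(7, 3))   # [3, 2, 2]
print(tf_even_split_sizes(9, 3))  # [3, 3, 3]
```

So the two ops agree whenever the length is divisible by the section count. A converter would need to handle the uneven case separately, for example by emitting explicit chunk sizes.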

mallman, May 21 '24 23:05

Oh, and here's an example of a failing conversion. This is from a script I've written for converting timm models:

import coremltools as ct
import timm
import torch

model_name = "eva02_tiny_patch14_224.mim_in22k"
print(f"Creating model {model_name}")
timm_model = timm.create_model(
  model_name,
  pretrained=True,
  scriptable=False,
  exportable=True)

model = torch.nn.Sequential(
  timm_model,
  torch.nn.Softmax(1)
).eval()

input_size = timm_model.default_cfg.get("input_size")
input_shape = (1,) + input_size

print("Tracing model")
example_input = torch.randn(input_shape)
jit_model = torch.jit.trace(model, example_input)

labels_filename = "imagenet21k_wordnet_lemmas.txt"

with open(labels_filename, "r") as labels_file:
  labels = [line.strip() for line in labels_file.readlines()]

classifier_config = ct.ClassifierConfig(labels)

print("Converting model")
# Scale and bias calculations taken from Core ML Tools documentation on
# preprocessing for PyTorch
mean = list(timm_model.default_cfg.get("mean"))
std = list(timm_model.default_cfg.get("std"))
import statistics
mean_std = statistics.mean(std)
scale = 1 / (mean_std * 255)
bias = [-m / s for m, s in zip(mean, std)]
input_type = ct.ImageType(
      name="image",
      shape=input_shape,
      scale=scale,
      bias=bias)

coreml_model = ct.convert(
  jit_model,
  convert_to="mlprogram",
  inputs=[input_type],
  classifier_config=classifier_config,
  skip_model_load=True
)

coreml_model.user_defined_metadata["com.apple.coreml.model.preview.type"] = "imageClassifier"

coreml_model_file_name = f"{model_name}.mlpackage"
print(f"Saving model to {coreml_model_file_name}")

coreml_model.save(coreml_model_file_name)
print("Done!")
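The scale and bias lines in the script can be checked against the normalization timm applies. Assume that Core ML's `ImageType` computes `scale * x + bias` per channel on pixel values $x_c \in [0, 255]$, and that timm normalizes each channel as $(x_c/255 - \mu_c)/\sigma_c$ using the `mean` ($\mu_c$) and `std` ($\sigma_c$) from `default_cfg`. Under those assumptions, the arithmetic works out as follows:

```latex
\frac{x_c/255 - \mu_c}{\sigma_c}
  = \underbrace{\frac{1}{255\,\sigma_c}}_{\text{scale}_c}\, x_c
    \;+\; \underbrace{\left(-\frac{\mu_c}{\sigma_c}\right)}_{\text{bias}_c}

% ImageType takes a single scale for all channels, so the script substitutes
% the average standard deviation:
\bar{\sigma} = \tfrac{1}{3}\,(\sigma_R + \sigma_G + \sigma_B),
\qquad
\text{scale} = \frac{1}{255\,\bar{\sigma}},
\qquad
\text{bias}_c = -\frac{\mu_c}{\sigma_c}
```

The per-channel bias is exact. The shared scale matches exactly only when all three $\sigma_c$ are equal, and is otherwise a close approximation.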

I believe a pip install with the timm, torch and coremltools packages will give you the right environment for running this.

You will also need a labels file, imagenet21k_wordnet_lemmas.txt, in your working directory. I've attached that file to this issue.

mallman, May 21 '24 23:05

Here is a more concise way to reproduce the issue:

import torch
import coremltools as ct

class M(torch.nn.Module):
    def forward(self, x):
        return torch.tensor_split(x, 3)

x = torch.arange(8)
traced_model = torch.jit.trace(M(), x)
ct.convert(traced_model, inputs=[ct.TensorType(shape=x.shape)])

I think we should be able to use the split MIL ops at least for simple cases.
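Both forms of `tensor_split` reduce to a list of chunk sizes, which is the input a size-based split needs. The index form (`x[:i0]`, `x[i0:i1]`, ..., `x[ik:]`) can be converted by differencing consecutive boundaries. Below is a pure-Python sketch. `indices_to_split_sizes` is a hypothetical name, not part of coremltools, and the sketch assumes non-negative, non-decreasing indices. The resulting list is the kind of value that could feed the `split_sizes` argument of the MIL `split` op, but I haven't checked how the PyTorch frontend would wire this up.

```python
def indices_to_split_sizes(length, indices):
    """Hypothetical helper: chunk sizes for torch.tensor_split(x, indices)
    along a dimension of size `length`.

    Assumes non-negative, non-decreasing indices. Indices past the end
    produce empty (size 0) chunks, matching Python slicing.
    """
    if any(i < 0 for i in indices):
        raise ValueError("negative indices are not handled in this sketch")
    if any(a > b for a, b in zip(indices, indices[1:])):
        raise ValueError("indices must be non-decreasing in this sketch")
    # Boundaries: 0, each index clamped to `length`, then `length` itself.
    bounds = [0] + [min(i, length) for i in indices] + [length]
    return [end - start for start, end in zip(bounds, bounds[1:])]


print(indices_to_split_sizes(8, [2, 5]))   # [2, 3, 3]
print(indices_to_split_sizes(8, [3, 10]))  # [3, 5, 0]
```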

TobyRoseman, May 22 '24 18:05

It looks like this can be worked around by using torch.split together with torch.unbind. See the documentation for both:

  • https://pytorch.org/docs/stable/generated/torch.unbind.html
  • https://pytorch.org/docs/stable/generated/torch.split.html

An example implementation is below, and also in this paste (https://pastes.dev/kkaPViedJ7):

import torch
import coremltools as ct

class M(torch.nn.Module):
    def forward(self, x):
        splits = torch.split(x, x.size(0) // 3)
        return torch.unbind(torch.stack(splits))

x = torch.arange(9)  
traced_model = torch.jit.trace(M(), x)
ct.convert(traced_model, inputs=[ct.TensorType(shape=x.shape)])
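The stack/unbind version returns three tensors only when the length is divisible by 3. For `torch.arange(8)`, `torch.split(x, 8 // 3)` yields four chunks of size 2. For `torch.arange(10)`, the last chunk has size 1, so `torch.stack` raises an error. A possible alternative is to compute `tensor_split`'s chunk sizes up front and pass that list to `torch.split`. This assumes coremltools converts `torch.split` with a list of sizes, which I have not verified against 7.2. `tensor_split_via_split` is a hypothetical helper, and it only covers the integer-sections form.

```python
import torch


def tensor_split_via_split(x, sections, dim=0):
    """Emulate torch.tensor_split(x, sections, dim) for an integer section
    count, using torch.split with explicit chunk sizes (hypothetical helper)."""
    # int() makes the length a plain Python int; under torch.jit.trace this
    # bakes the traced length in as a constant, fine for fixed input shapes.
    length = int(x.size(dim))
    base, extra = divmod(length, sections)
    # Per the tensor_split docs, the first `extra` chunks get one extra element.
    sizes = [base + 1] * extra + [base] * (sections - extra)
    return torch.split(x, sizes, dim=dim)


class M(torch.nn.Module):
    def forward(self, x):
        return tensor_split_via_split(x, 3)


x = torch.arange(8)
print([t.tolist() for t in M()(x)])  # [[0, 1, 2], [3, 4, 5], [6, 7]]
```

Unlike the stack/unbind approach, this matches `torch.tensor_split` for lengths not divisible by 3, such as the `torch.arange(8)` input from the earlier repro.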

teelrabbit, Jun 09 '24 04:06