
Support upsample_trilinear3d

likesum opened this issue 2 years ago • 2 comments

  • Name of layer type: upsample_trilinear3d
  • Is this a PyTorch or a TensorFlow layer type: PyTorch
  • Your version of coremltools: 6.2
  • Your version of PyTorch/TensorFlow: 1.12.1
  • Impact of supporting this layer type. Why is adding support for this layer type important? Is it necessary to support a popular model or use case? Trilinear 3D upsampling is commonly used in 3D networks, volumetric rendering, 3D look-up tables, and many other cases. PyTorch supports trilinear mode in both torch.nn.functional.interpolate and torch.nn.Upsample.

Example to reproduce:

import torch
import coremltools as ct
import torch.nn.functional as F


class Net(torch.nn.Module):

  def forward(self, x):
    return F.interpolate(x, scale_factor=2.0, mode="trilinear")


class Net2(torch.nn.Module):

  def __init__(self) -> None:
    super().__init__()
    self.upsample3d = torch.nn.Upsample(scale_factor=2.0, mode="trilinear")

  def forward(self, x):
    return self.upsample3d(x)


input_tensor = torch.zeros([1, 8, 16, 16, 16], dtype=torch.float32)

# Check torch.nn.functional.interpolate
torch_model = Net()
traced_model = torch.jit.trace(torch_model, input_tensor)

model_ct = ct.convert(traced_model,
                      inputs=[ct.TensorType(shape=input_tensor.shape)])

# Check torch.nn.Upsample
torch_model = Net2()
traced_model = torch.jit.trace(torch_model, input_tensor)

model_ct = ct.convert(traced_model,
                      inputs=[ct.TensorType(shape=input_tensor.shape)])

Error message:

RuntimeError: PyTorch convert function for op 'upsample_trilinear3d' not implemented.

likesum · Feb 24 '23 18:02

Given that coremltools already supports Conv3D, support for 3D upsampling layers would be a logical addition. It is much needed for medical image analysis, video analysis, and other volumetric applications. I tried to implement it myself in coremltools, but I think Core ML itself does not support 3D upsampling. I got stuck here:

/Users/laves/projects/coremltools/coremltools/models/model.py:154: 
RuntimeWarning: You will not be able to run predict() on this Core ML model. 
Underlying exception message was: Error compiling model: "Failed to parse the 
model specification. Error: Unable to parse ML Program: in operation op_5_cast: 
For operation of type 'upsample_nearest_neighbor' number of inputs must be 
within the range (inclusive): 3 : 3. Provided 4".
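
For context, the underlying ML Program op upsample_nearest_neighbor only exposes height and width scale factors, so there is nowhere to pass a depth scale for a 5-D tensor, which is why the fourth input is rejected. A minimal MIL Builder sketch of the 2-D op (assuming the standard Builder API; not taken from the original comment):

from coremltools.converters.mil import Builder as mb

# The op takes exactly x, scale_factor_height, and scale_factor_width;
# there is no scale_factor_depth, so 3D upsampling cannot be expressed directly.
@mb.program(input_specs=[mb.TensorSpec(shape=(1, 8, 16, 16))])
def prog(x):
    return mb.upsample_nearest_neighbor(
        x=x, scale_factor_height=2.0, scale_factor_width=2.0
    )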

At least upsample_nearest3d could be emulated for integer scale factors using mb.conv_transpose with an SxSxS kernel filled with ones and strides S, S, S, where S is the scale factor:

import numpy as np

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op


@register_torch_op
def upsample_nearest3d(context, node):
    # Inputs of the traced op: x, output_size, scale_factors.
    inputs = _get_inputs(context, node, expected=3)
    x = inputs[0]
    scales = inputs[2]

    c = x.shape[1]
    # PyTorch passes the scale factors in (depth, height, width) order.
    s_d, s_h, s_w = map(int, scales.val)

    # A grouped transposed convolution with an all-ones kernel whose size equals
    # its stride copies every voxel into an s_d x s_h x s_w block, i.e.
    # nearest-neighbor upsampling for integer scale factors.
    x = mb.conv_transpose(
        x=x,
        weight=np.ones((c, 1, s_d, s_h, s_w)),
        strides=[s_d, s_h, s_w],
        groups=c,
        name=node.name,
    )

    context.add(x)
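
As a quick sanity check of the trick above (not part of the original snippet), plain PyTorch can confirm that a grouped transposed convolution with an all-ones SxSxS kernel and stride S reproduces nearest-neighbor upsampling for integer scales, because the kernel size equals the stride so the output blocks never overlap:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 4, 4, 4)
S = 2  # integer scale factor
C = x.shape[1]

# Grouped conv_transpose3d; weight shape is (C_in, C_out / groups, kD, kH, kW).
weight = torch.ones(C, 1, S, S, S)
y_conv = F.conv_transpose3d(x, weight, stride=S, groups=C)

# Reference nearest-neighbor upsampling.
y_ref = F.interpolate(x, scale_factor=S, mode="nearest")

assert torch.allclose(y_conv, y_ref)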

mlaves · Aug 03 '23 10:08

@mlaves - to request changes to the Core ML Framework, please use the Feedback Assistant.

TobyRoseman · Aug 03 '23 22:08