coremltools

InstanceNorm3D missing

Open · mlaves opened this issue 2 years ago · 2 comments

I would like to implement InstanceNorm3D. Given that BatchNorm3D and InstanceNorm2D are already available as MIL ops, this should be straightforward. However, it's not clear to me where to start, or how to extend the instance_norm MIL op to rank 5. For inference with a batch size of 1, one could use batch norm instead of instance norm, provided the batch norm normalizes with the statistics of the current input rather than stored running statistics: with a single sample, the batch statistics are exactly the per-instance statistics. Any advice is appreciated.
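As a quick check of the batch-size-1 argument, here is a minimal NumPy sketch (no coremltools involved; both function names are hypothetical, chosen for illustration). It compares instance norm against batch norm computed from the statistics of the current batch, assuming PyTorch's conventions of a biased variance and eps=1e-5.

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    # Statistics per sample and per channel: reduce over the spatial axes (D, H, W).
    mean = x.mean(axis=(2, 3, 4), keepdims=True)
    var = x.var(axis=(2, 3, 4), keepdims=True)  # biased variance, as in PyTorch
    return (x - mean) / np.sqrt(var + eps)

def batch_norm_batch_stats(x, eps=1e-5):
    # Batch norm using the current batch's statistics: reduce over N and the spatial axes.
    mean = x.mean(axis=(0, 2, 3, 4), keepdims=True)
    var = x.var(axis=(0, 2, 3, 4), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x1 = rng.standard_normal((1, 4, 3, 5, 6))  # batch size 1
x2 = rng.standard_normal((2, 4, 3, 5, 6))  # batch size 2

print(np.allclose(instance_norm(x1), batch_norm_batch_stats(x1)))  # True
print(np.allclose(instance_norm(x2), batch_norm_batch_stats(x2)))  # False
```

The two only agree for N=1. A batch norm layer that uses its running mean and variance (the usual inference setting) would not reproduce instance norm, even at batch size 1.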

Edit: After digging around in the code, I see that there's a custom implementation for the batchnorm_3d case here. I guess the same has to be done for instance norm 3D?
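One possible way to reach rank 5 without a new MIL op is to fold D into H, run a rank-4 instance norm, and reshape back; in MIL terms this would be a reshape, instance_norm, reshape sequence. The NumPy sketch below assumes that the rank-4 instance_norm MIL op normalizes each (N, C) slice over its last two axes; `instance_norm_nd` only stands in for that op, and `instance_norm_3d_via_2d` is a hypothetical helper. I have not checked whether this matches what the existing batchnorm_3d workaround does.

```python
import numpy as np

def instance_norm_nd(x, eps=1e-5):
    # Reference: normalize each (sample, channel) slice over all spatial axes.
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def instance_norm_3d_via_2d(x, eps=1e-5):
    # Hypothetical workaround: fold D into H so a rank-4 instance norm,
    # which reduces over its last two axes, sees all D*H*W elements of each slice.
    n, c, d, h, w = x.shape
    x4 = x.reshape(n, c, d * h, w)
    y4 = instance_norm_nd(x4, eps)  # stands in for the rank-4 MIL instance_norm
    return y4.reshape(n, c, d, h, w)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5, 6))
print(np.allclose(instance_norm_nd(x), instance_norm_3d_via_2d(x)))  # True
```

The fold works because a row-major reshape keeps every (N, C) block contiguous, so each slice is normalized over the same set of elements. Per-channel gamma and beta are unaffected, since the channel axis stays in place.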

mlaves · Sep 24 '23 14:09

I managed to make a basic implementation of instance_norm_3d as a composite operation. It only covers the default behaviour of PyTorch's InstanceNorm3d, i.e. normalizing with per-instance statistics rather than running statistics. I'll create a PR for this if I find the time.

import coremltools as ct
from coremltools.converters.mil import Builder as mb

from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs

del _TORCH_OPS_REGISTRY["instance_norm"]  # only required if overwriting an existing translation

@register_torch_op
def instance_norm(context, node):
    # aten::instance_norm(input, weight, bias, running_mean, running_var,
    #                     use_input_stats, momentum, eps, cudnn_enabled)
    inputs = _get_inputs(context, node, expected=9)
    x = inputs[0]
    weight = inputs[1]
    bias = inputs[2]
    eps = inputs[7]

    # Reduce over all spatial axes: [2, 3] for rank 4, [2, 3, 4] for rank 5,
    # so replacing the registered translation does not break 2D instance norm.
    axes = list(range(2, x.rank))
    # Broadcastable shape for per-channel (C,) parameters, valid for any batch size.
    param_shape = [1, -1] + [1] * (x.rank - 2)

    # Implement instance norm from scratch (biased variance, as in PyTorch).
    mean = mb.reduce_mean(x=x, axes=axes, keep_dims=True)
    sub = mb.sub(x=x, y=mean)  # unnamed: node.name is reserved for the final op
    squared = mb.mul(x=sub, y=sub)
    variance = mb.reduce_mean(x=squared, axes=axes, keep_dims=True)
    variance_eps = mb.add(x=variance, y=eps)
    std = mb.sqrt(x=variance_eps)

    name = node.name if weight is None and bias is None else node.name + "_div"
    x = mb.real_div(x=sub, y=std, name=name)

    if weight is not None:
        weight_reshape = mb.reshape(x=weight, shape=param_shape)
        name = node.name if bias is None else node.name + "_mul"
        x = mb.mul(x=x, y=weight_reshape, name=name)

    if bias is not None:
        bias_reshape = mb.reshape(x=bias, shape=param_shape)
        x = mb.add(x=x, y=bias_reshape, name=node.name)

    context.add(x)
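To sanity-check the decomposition in the composite op above, here is a NumPy reference that mirrors the same op sequence (reduce_mean, sub, mul, reduce_mean, add, sqrt, real_div, then the optional per-channel scale and shift). `instance_norm_reference` is a hypothetical helper for illustration; eps=1e-5 matches PyTorch's default, and something like it could supply expected values for a unit test.

```python
import numpy as np

def instance_norm_reference(x, weight=None, bias=None, eps=1e-5):
    # Mirrors the MIL ops: reduce_mean, sub, mul, reduce_mean, add, sqrt, real_div.
    axes = tuple(range(2, x.ndim))  # spatial axes
    mean = x.mean(axis=axes, keepdims=True)
    sub = x - mean
    variance = (sub * sub).mean(axis=axes, keepdims=True)
    out = sub / np.sqrt(variance + eps)
    # Per-channel parameters broadcast as (1, C, 1, ..., 1), so any batch size works.
    param_shape = (1, -1) + (1,) * (x.ndim - 2)
    if weight is not None:
        out = out * np.reshape(weight, param_shape)
    if bias is not None:
        out = out + np.reshape(bias, param_shape)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5, 6)) * 3.0 + 1.0
y = instance_norm_reference(x)
# Without affine parameters, every (sample, channel) slice has mean ~0 and variance ~1.
print(np.allclose(y.mean(axis=(2, 3, 4)), 0.0, atol=1e-6))  # True
print(np.allclose(y.var(axis=(2, 3, 4)), 1.0, atol=1e-3))   # True

# With affine parameters, each slice's mean becomes the channel's bias.
w = np.array([0.5, 2.0, -1.0])
b = np.array([0.1, -0.2, 0.3])
y_affine = instance_norm_reference(x, w, b)
print(np.allclose(y_affine.mean(axis=(2, 3, 4)), b[None, :]))  # True
```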

mlaves · Sep 25 '23 21:09

Sending us a PR would be great. Please include a unit test.

TobyRoseman · Sep 25 '23 22:09