InstanceNorm3D missing
I would like to implement InstanceNorm3D. Given that BatchNorm3D and InstanceNorm2D are already available as MIL ops, this should be straightforward. However, it's not clear to me where to start or how to extend the instance_norm MIL op to rank 5. For inference with a batch size of 1, one could instead replace instance norm with batch norm, provided the batch norm normalizes with the input's own mean and variance rather than with running statistics.
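A quick NumPy check of the batch-size-1 equivalence: instance norm takes statistics per sample and channel over (D, H, W), while batch norm with batch statistics takes them per channel over (N, D, H, W). With N = 1 the two reductions cover the same elements. The function names are hypothetical and the affine weight/bias are omitted for brevity; this is a sketch, not coremltools code.

```python
import numpy as np

def instance_norm_3d(x, eps=1e-5):
    # Per-sample, per-channel statistics over the spatial axes (D, H, W).
    mean = x.mean(axis=(2, 3, 4), keepdims=True)
    var = x.var(axis=(2, 3, 4), keepdims=True)  # biased variance
    return (x - mean) / np.sqrt(var + eps)

def batch_norm_3d_batch_stats(x, eps=1e-5):
    # Per-channel statistics over batch AND spatial axes (N, D, H, W),
    # i.e. batch norm using the current input's statistics, not running ones.
    mean = x.mean(axis=(0, 2, 3, 4), keepdims=True)
    var = x.var(axis=(0, 2, 3, 4), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x1 = rng.standard_normal((1, 4, 3, 5, 6)).astype(np.float32)  # batch size 1
x2 = rng.standard_normal((2, 4, 3, 5, 6)).astype(np.float32)  # batch size 2

same_n1 = np.allclose(instance_norm_3d(x1), batch_norm_3d_batch_stats(x1), atol=1e-5)
same_n2 = np.allclose(instance_norm_3d(x2), batch_norm_3d_batch_stats(x2), atol=1e-5)
print(same_n1, same_n2)  # prints: True False
```

With two samples the statistics get pooled across the batch, so the substitution only holds for N = 1.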
Any advice on that is appreciated.
Edit: After digging around in the code, I see that there's a custom implementation for batchnorm_3d here. I guess the same has to be done for instance norm 3D?
I managed to make a basic implementation of instance_norm_3d as a composite operation. It only implements the default InstanceNorm3D from PyTorch (normalization with the input's own statistics; running statistics are ignored). I'll create a PR for this if I find the time.
import coremltools as ct
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs

del _TORCH_OPS_REGISTRY["instance_norm"]  # only required if over-writing an existing translation

@register_torch_op
def instance_norm(context, node):
    # aten::instance_norm takes 9 inputs: input, weight, bias, running_mean,
    # running_var, use_input_stats, momentum, eps, cudnn_enabled
    inputs = _get_inputs(context, node, expected=9)
    x = inputs[0]
    weight = inputs[1]  # may be None (e.g. affine=False)
    bias = inputs[2]    # may be None (e.g. affine=False)
    eps = inputs[7]

    # implement instance norm from scratch; the axes assume a rank-5 (N, C, D, H, W)
    # input, so this translation (which replaces the built-in one) only covers InstanceNorm3d
    mean = mb.reduce_mean(x=x, axes=[2, 3, 4], keep_dims=True)
    # intermediate ops get derived names; node.name is reserved for the final output
    sub = mb.sub(x=x, y=mean, name=node.name + "_sub")
    squared = mb.mul(x=sub, y=sub)
    # biased variance over the spatial axes, as PyTorch uses for normalization
    variance = mb.reduce_mean(x=squared, axes=[2, 3, 4], keep_dims=True)
    variance_eps = mb.add(x=variance, y=eps)
    std = mb.sqrt(x=variance_eps)
    name = node.name if weight is None and bias is None else node.name + "_div"
    x = mb.real_div(x=sub, y=std, name=name)

    # per-channel affine parameters have shape (C,); reshape to (1, C, 1, 1, 1)
    # so they broadcast over any batch size
    if weight is not None:
        weight_reshape = mb.reshape(x=weight, shape=[1, -1, 1, 1, 1])
        name = node.name if bias is None else node.name + "_mul"
        x = mb.mul(x=x, y=weight_reshape, name=name)
    if bias is not None:
        bias_reshape = mb.reshape(x=bias, shape=[1, -1, 1, 1, 1])
        x = mb.add(x=x, y=bias_reshape, name=node.name)
    context.add(x)
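For reference, the composite graph in the code above computes, for sample $n$ and channel $c$ over the $D \times H \times W$ spatial positions (with $\gamma$ = weight and $\beta$ = bias, each term dropped when the corresponding input is None):

```latex
\begin{aligned}
\mu_{nc} &= \frac{1}{DHW}\sum_{d,h,w} x_{ncdhw}, \qquad
\sigma^2_{nc} = \frac{1}{DHW}\sum_{d,h,w}\left(x_{ncdhw} - \mu_{nc}\right)^2, \\
y_{ncdhw} &= \gamma_c \, \frac{x_{ncdhw} - \mu_{nc}}{\sqrt{\sigma^2_{nc} + \epsilon}} + \beta_c
\end{aligned}
```

The variance is the biased one (divided by $DHW$), which matches what PyTorch uses when normalizing.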
Sending us a PR would be great. Please include a unit test.
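A unit test needs expected values to compare the converted model against. Besides comparing directly with `torch.nn.InstanceNorm3d(C, affine=True)`, one option is a small NumPy reference that mirrors the composite MIL graph op by op. The sketch below assumes that approach; `instance_norm_3d_reference` is a hypothetical helper, not part of coremltools.

```python
import numpy as np

def instance_norm_3d_reference(x, weight=None, bias=None, eps=1e-5):
    """NumPy mirror of the composite MIL graph, one line per MIL op."""
    mean = np.mean(x, axis=(2, 3, 4), keepdims=True)               # reduce_mean
    sub = x - mean                                                 # sub
    variance = np.mean(sub * sub, axis=(2, 3, 4), keepdims=True)   # mul + reduce_mean
    std = np.sqrt(variance + eps)                                  # add + sqrt
    y = sub / std                                                  # real_div
    if weight is not None:
        y = y * weight.reshape(1, -1, 1, 1, 1)                     # reshape + mul
    if bias is not None:
        y = y + bias.reshape(1, -1, 1, 1, 1)                       # reshape + add
    return y

rng = np.random.default_rng(42)
x = rng.standard_normal((2, 3, 4, 5, 6)).astype(np.float32)  # batch size 2, 3 channels
gamma = np.array([1.0, 2.0, 0.5], dtype=np.float32)
beta = np.array([0.0, -1.0, 3.0], dtype=np.float32)

y_plain = instance_norm_3d_reference(x)
y_affine = instance_norm_3d_reference(x, gamma, beta)
```

Using a batch size above 1 in the test also exercises the per-channel reshape of weight and bias.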