
PyTorch's BatchNorm2d causes script conversion error

Open JRGit4UE opened this issue 3 years ago • 4 comments

Scripting a module that contains PyTorch's BatchNorm2d, defined in
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
and then converting it fails on PyTorch 1.9.1 through 1.11 with Python 3.8 and coremltools 5.2.0.

import coremltools as ct
import torch
from torch import nn, Tensor
import sys

class MyBatchNormModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(100, affine=False)
    def forward(self, x: Tensor)->Tensor:
        res: Tensor = self.bn(x)
        return res

if __name__ == '__main__':
    t1 = MyBatchNormModule1()
    t1.eval()
    x = torch.rand(20, 100, 35, 45)
    t = torch.jit.script(t1)
    result = t1(x)
    model = ct.convert(t, inputs=[ct.TensorType(shape=x.shape)], convert_to='mlprogram', debug=True)
    sys.exit(0)

The conversion fails with an AssertionError, assert len(cond) == len(node.outputs), raised in converters/mil/frontend/torch/ops.py:
https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py

@register_torch_op(torch_alias=["if"])
def _if(context, node):
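To see where the `if` ops come from, the scripted graph can be inspected directly. This is a minimal sketch, assuming that the Python branches inside BatchNorm2d's forward survive scripting as `prim::If` nodes; the names `graph_text` and `if_count` are only illustrative, and the exact count may differ between PyTorch versions.

```python
import torch
from torch import nn, Tensor


class MyBatchNormModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(100, affine=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.bn(x)


t1 = MyBatchNormModule1()
t1.eval()
t = torch.jit.script(t1)

# inlined_graph expands the call into the BatchNorm2d submodule,
# so control flow from its forward() becomes visible in the text dump.
graph_text = str(t.inlined_graph)

# Count conditional nodes; these are the ops a converter has to
# handle through its "if" handler (assumption: they are kept by scripting).
if_count = graph_text.count("prim::If")
print(f"prim::If nodes in scripted graph: {if_count}")
```

If the count is non-zero, the converter has to translate those conditionals, which seems to be the step that fails here.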

JRGit4UE avatar Apr 18 '22 15:04 JRGit4UE

A possible workaround seems to be partial tracing of the BatchNorm2d layer:

self.bn = torch.jit.trace(nn.BatchNorm2d(100, affine=False), example_inputs=torch.rand(2, 100, 2, 2))

where example_inputs has shape (2, num_features, 2, 2). Full example:

import coremltools as ct
import torch
from torch import nn, Tensor
import sys

class MyBatchNormModule1(nn.Module):
    def __init__(self):
        super().__init__()
        num_features = 100
        # Trace only the BatchNorm2d layer; the outer module is still scripted below.
        self.bn = torch.jit.trace(nn.BatchNorm2d(num_features, affine=False), example_inputs=torch.rand(2, num_features, 2, 2))
    def forward(self, x: Tensor)->Tensor:
        res: Tensor = self.bn(x)
        return res

if __name__ == '__main__':
    t1 = MyBatchNormModule1()
    t1.eval()
    x = torch.rand(20, 100, 35, 45)
    t = torch.jit.script(t1)
    result1 = t1(x)
    result2 = t(x)
    check = torch.equal(result1, result2) 
    model = ct.convert(t, inputs=[ct.TensorType(shape=x.shape)], convert_to='mlprogram', debug=True)
    sys.exit(0)

JRGit4UE avatar Apr 18 '22 19:04 JRGit4UE

Although the simple workaround above seems to work, it leads to incomplete results. :/
Maybe someone with deeper knowledge of batch normalization can figure out a proper solution.

JRGit4UE avatar Apr 18 '22 19:04 JRGit4UE

Our support for TorchScript is experimental. Conversion of your model works fine if you trace the torch model first.

As a workaround, please trace your model before conversion: just before the ct.convert call, run t = torch.jit.trace(t1, x).
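A minimal sketch of that workaround, reusing the module from the original report. The helper name `to_coreml` is hypothetical and only wraps the same ct.convert call as above; coremltools is imported inside it, so the tracing part also runs without coremltools installed.

```python
import torch
from torch import nn, Tensor


class MyBatchNormModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(100, affine=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.bn(x)


def to_coreml(traced, example: Tensor):
    """Hypothetical helper: convert an already traced module to an ML program."""
    import coremltools as ct  # imported lazily; requires coremltools to be installed

    return ct.convert(
        traced,
        inputs=[ct.TensorType(shape=example.shape)],
        convert_to="mlprogram",
    )


t1 = MyBatchNormModule1()
t1.eval()  # eval mode, so BatchNorm2d uses its running statistics
x = torch.rand(20, 100, 35, 45)

# Trace instead of script: only the ops executed for x are recorded,
# so no conditional nodes end up in the graph.
t = torch.jit.trace(t1, x)

# Sanity check that the traced module reproduces the eager output.
traced_matches_eager = torch.allclose(t1(x), t(x))
print("traced output matches eager output:", traced_matches_eager)
```

Calling `model = to_coreml(t, x)` then performs the actual conversion.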

TobyRoseman avatar Apr 18 '22 23:04 TobyRoseman

@TobyRoseman Thanks for your suggestion. I traced the model and got this warning:

TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

According to https://stackoverflow.com/questions/66746307/torch-jit-trace-tracerwarning-converting-a-tensor-to-a-python-boolean-might-c this just means that I tried to torch.jit.trace a model with data-dependent control flow.
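To illustrate what this warning means in practice, here is a small sketch that is unrelated to the actual model. The function `scale_by_first` is made up for this example; it converts a tensor element to a Python int, which the tracer can only record as a constant.

```python
import torch


def scale_by_first(x: torch.Tensor) -> torch.Tensor:
    # int(...) moves a value out of the tensor into plain Python,
    # so the tracer cannot follow it and emits a TracerWarning.
    k = int(x[0])
    return x * k


# During tracing, k is 2 and gets baked into the graph as a constant.
traced = torch.jit.trace(scale_by_first, torch.tensor([2.0, 1.0]))

other = torch.tensor([3.0, 1.0])
eager_out = scale_by_first(other)  # eager code recomputes k from the new input
traced_out = traced(other)         # traced code still multiplies by the recorded k

print("eager: ", eager_out)
print("traced:", traced_out)
print("same result:", torch.equal(eager_out, traced_out))
```

The two outputs differ for any input whose first element is not the one seen during tracing, which is the "might not generalize to other inputs" part of the warning.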

On the other hand, going the torch.jit.script() route works, but then fails at ct.convert. That's why I would really appreciate getting some TorchScript issues fixed in coremltools.

JRGit4UE avatar Apr 20 '22 10:04 JRGit4UE