PyTorch's BatchNorm2d causes script conversion error
PyTorch's BatchNorm2d (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py) causes a torch.jit.script conversion error on PyTorch 1.9.1 - 1.11 with Python 3.8 and CoreMLTools 5.2.0. Minimal reproduction:
```python
import coremltools as ct
import torch
from torch import nn, Tensor
import sys


class MyBatchNormModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(100, affine=False)

    def forward(self, x: Tensor) -> Tensor:
        res: Tensor = self.bn(x)
        return res


if __name__ == '__main__':
    t1 = MyBatchNormModule1()
    t1.eval()
    x = torch.rand(20, 100, 35, 45)
    t = torch.jit.script(t1)
    result = t1(x)
    model = ct.convert(t, inputs=[ct.TensorType(shape=x.shape)], convert_to='mlprogram', debug=True)
    sys.exit(0)
```
The error is an AssertionError (assert len(cond) == len(node.outputs)) raised in converters/mil/frontend/torch/ops.py
(https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py), inside the "if" op handler:

```python
@register_torch_op(torch_alias=["if"])
def _if(context, node):
```
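Those prim::If ops most likely come from the data-dependent branches in BatchNorm2d's forward (the training / track_running_stats logic), which survive torch.jit.script but not tracing. A quick diagnostic sketch (my own, not part of the converter) to confirm the scripted graph really contains prim::If nodes:

```python
import torch
from torch import nn

# Script a bare BatchNorm2d and count the prim::If nodes in its forward graph;
# these are the "if" ops that CoreMLTools' _if handler receives.
bn = torch.jit.script(nn.BatchNorm2d(100, affine=False).eval())
if_nodes = [n for n in bn.graph.nodes() if n.kind() == "prim::If"]
print(f"prim::If nodes in scripted BatchNorm2d graph: {len(if_nodes)}")
print(bn.graph)  # full graph dump for inspection
```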
A possible workaround seems to be partial tracing of the BatchNorm2d layer:

```python
self.bn = torch.jit.trace(nn.BatchNorm2d(100, affine=False), example_inputs=torch.rand(2, 100, 2, 2))
```

where example_inputs has shape (2, num_features, 2, 2):
```python
import coremltools as ct
import torch
from torch import nn, Tensor
import sys


class MyBatchNormModule1(nn.Module):
    def __init__(self):
        super().__init__()
        xxx = 100  # num_features
        self.bn = torch.jit.trace(nn.BatchNorm2d(xxx, affine=False), example_inputs=torch.rand(2, xxx, 2, 2))

    def forward(self, x: Tensor) -> Tensor:
        res: Tensor = self.bn(x)
        return res


if __name__ == '__main__':
    t1 = MyBatchNormModule1()
    t1.eval()
    x = torch.rand(20, 100, 35, 45)
    t = torch.jit.script(t1)
    result1 = t1(x)
    result2 = t(x)
    check = torch.equal(result1, result2)
    model = ct.convert(t, inputs=[ct.TensorType(shape=x.shape)], convert_to='mlprogram', debug=True)
    sys.exit(0)
```
Although the simple workaround example above seems to work, it leads to incomplete results. :/
Perhaps someone with deeper knowledge of batch normalization can figure out a proper solution.
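For reference, one way to quantify how far off the converted model is would be to compare its predictions against the PyTorch output directly. This is only a sketch: it assumes macOS with the Core ML runtime available, and it names the input explicitly (name="x" is my addition) so the key passed to predict() is known.

```python
import numpy as np
import coremltools as ct
import torch

# Assumes `t` (the scripted module with the traced BatchNorm2d) and `x`
# from the workaround example above.
model = ct.convert(t, inputs=[ct.TensorType(name="x", shape=x.shape)], convert_to='mlprogram')

torch_out = t(x).detach().numpy()
coreml_out = model.predict({"x": x.numpy()})   # dict keyed by output name
coreml_arr = list(coreml_out.values())[0]

print("max abs diff:", np.max(np.abs(torch_out - coreml_arr)))
print("allclose:", np.allclose(torch_out, coreml_arr, atol=1e-4))
```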
Our support for TorchScript is experimental. Conversion of your model works fine if you trace the torch model first.
As a workaround, please trace your model before conversion: just before the ct.convert call, run t = torch.jit.trace(t1, x).
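Applied to the original script, the suggestion would look roughly like this (a sketch only, assuming the same MyBatchNormModule1 and x as in the first example):

```python
import coremltools as ct
import torch

t1 = MyBatchNormModule1()
t1.eval()
x = torch.rand(20, 100, 35, 45)

# Trace the whole model instead of scripting it, then convert the traced module.
t = torch.jit.trace(t1, x)
model = ct.convert(t, inputs=[ct.TensorType(shape=x.shape)], convert_to='mlprogram')
```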
@TobyRoseman Thanks for your suggestion.
I have traced the model and got the warning

TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

which, according to https://stackoverflow.com/questions/66746307/torch-jit-trace-tracerwarning-converting-a-tensor-to-a-python-boolean-might-c,
just means that I have tried to torch.jit.trace a model with data-dependent control flow.
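To illustrate what that warning means in general, here is a minimal, self-contained example with a hypothetical module (not taken from BatchNorm itself): tracing records only the branch that the example input happens to take, while scripting preserves both.

```python
import torch
from torch import nn, Tensor

class Branchy(nn.Module):
    # Hypothetical module with a data-dependent branch.
    def forward(self, x: Tensor) -> Tensor:
        if x.sum() > 0:                 # condition depends on the input's values
            return x * 2
        return torch.zeros_like(x)

m = Branchy()
traced = torch.jit.trace(m, torch.ones(3))   # emits a TracerWarning; only the "> 0" path is recorded
scripted = torch.jit.script(m)               # keeps both branches

neg = -torch.ones(3)
print(m(neg))         # eager:    tensor([0., 0., 0.])
print(scripted(neg))  # scripted: tensor([0., 0., 0.])
print(traced(neg))    # traced:   tensor([-2., -2., -2.])  -- the "> 0" branch was baked in
```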
On the other hand, going the torch.jit.script() route works, but conversion then stops at ct.convert.
That's why I would really appreciate it if some TorchScript issues were fixed in CoreMLTools.