Crash while converting models with torch.nn.BatchNorm3d layers
Describing the bug
Converting a PyTorch model that contains a torch.nn.BatchNorm3d layer to an mlprogram aborts with an LLVM error ("Failed to infer result type(s)"). The crash only occurs when the layer is created with affine=True (the default).
Stack Trace
python repro.py
Converting PyTorch Frontend ==> MIL Ops: 80%|████████████████████████████████        | 4/5 [00:00<00:00, 2077.93 ops/s]
Running MIL frontend_pytorch pipeline: 100%|████████████████████████████████████████| 5/5 [00:00<00:00, 8973.69 passes/s]
Running MIL default pipeline: 100%|████████████████████████████████████████| 79/79 [00:00<00:00, 5920.24 passes/s]
Running MIL backend_mlprogram pipeline: 100%|████████████████████████████████████████| 12/12 [00:00<00:00, 11594.48 passes/s]
loc("tensor<fp16, [1, 32, 16, 64, 64]> var_11_cast_fp16 = batch_norm(beta = tensor<fp16, [32]>(BLOBFILE(path = string(\22/private/var/folders/t3/3lnsqlv128ddvnvrh9rvpjmc0000gn/T/tmp2ulm8axj.mlmodelc/weights/weight.bin\22), offset = uint64(448))), epsilon = fp16(1.00135803e-05), gamma = tensor<fp16, [32]>(BLOBFILE(path = string(\22/private/var/folders/t3/3lnsqlv128ddvnvrh9rvpjmc0000gn/T/tmp2ulm8axj.mlmodelc/weights/weight.bin\22), offset = uint64(320))), mean = tensor<fp16, [32]>(BLOBFILE(path = string(\22/private/var/folders/t3/3lnsqlv128ddvnvrh9rvpjmc0000gn/T/tmp2ulm8axj.mlmodelc/weights/weight.bin\22), offset = uint64(64))), variance = tensor<fp16, [32]>(BLOBFILE(path = string(\22/private/var/folders/t3/3lnsqlv128ddvnvrh9rvpjmc0000gn/T/tmp2ulm8axj.mlmodelc/weights/weight.bin\22), offset = uint64(192))), x = x_to_fp16)[milId = uint64(2), name = string(\22op_11_cast_fp16\22)]; - /private/var/folders/t3/3lnsqlv128ddvnvrh9rvpjmc0000gn/T/tmp2ulm8axj.mlmodelc/model.mil":12:12): error: output type 'tensor<1x32x16x64x64xf16>' and mean type 'tensor<1x0x1x1x1329168176xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
zsh: abort python repro.py
/opt/homebrew/anaconda3/envs/coremltools-env/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
To Reproduce
Minimal code example that reproduces the crash:
import coremltools as ct
import torch


class Model(torch.nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.norm = torch.nn.BatchNorm3d(n_features)

    def forward(self, x):
        return self.norm(x)


model = Model(32).eval()
features = torch.randn((1, 32, 16, 64, 64))

with torch.no_grad():
    # Tracing succeeds; the abort happens during conversion to an ML Program.
    mlmodel = ct.convert(
        torch.jit.trace(model, features),
        inputs=[ct.TensorType(name="x", shape=features.shape)],
        outputs=[ct.TensorType(name="out")],
        convert_to="mlprogram",
    )
System environment (please complete the following information):
- coremltools version: 8.0b1
- OS (e.g. MacOS version or Linux type): macOS 15.0 Beta (24A5289h)
- Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.3.0
Additional context
- The crash does not occur if the BatchNorm3d layer is initialized with affine=False (the default value for that argument is True); see the sketch below.
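For reference, a minimal sketch of that affine=False variant (the class name below is just illustrative; everything else mirrors the repro script), which converts without crashing:

import coremltools as ct
import torch

class ModelNoAffine(torch.nn.Module):
    def __init__(self, n_features):
        super().__init__()
        # affine=False: no learnable gamma/beta, only running statistics.
        self.norm = torch.nn.BatchNorm3d(n_features, affine=False)

    def forward(self, x):
        return self.norm(x)

model = ModelNoAffine(32).eval()
features = torch.randn((1, 32, 16, 64, 64))
with torch.no_grad():
    mlmodel = ct.convert(
        torch.jit.trace(model, features),
        inputs=[ct.TensorType(name="x", shape=features.shape)],
        outputs=[ct.TensorType(name="out")],
        convert_to="mlprogram",
    )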
I can reproduce this issue.
I think this is a bug in the Core ML Framework. The ct.convert call works if you convert_to either neuralnetwork or milinternal. It's loading the converted mlprogram that's the issue; if you pass skip_model_load=True, the ct.convert call also succeeds.
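A quick sketch of those two workarounds, assuming the model and features objects from the repro script above are in scope (the save path is just illustrative):

import coremltools as ct
import torch

traced = torch.jit.trace(model, features)

# Workaround 1: convert to an mlprogram but skip loading the compiled model,
# which is the step that aborts.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=features.shape)],
    convert_to="mlprogram",
    skip_model_load=True,
)
mlmodel.save("batchnorm3d.mlpackage")

# Workaround 2: target the older neuralnetwork backend instead of mlprogram.
nn_model = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=features.shape)],
    convert_to="neuralnetwork",
)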
@TobyRoseman I'm going to give your suggestion a shot with the wrapper library https://github.com/huggingface/exporters, where I and others hit the very same segmentation fault when trying to convert any 'large' model: the conversion gets all the way to the end and then the process is killed.
A similar issue is open on that repo here.
Edit: this worked! Converted a 'large' model, no process killed.