How can I convert this model to the ncnn format? Converting the ONNX model directly with pnnx throws an exception.
When I run pnnx on this model, it raises an error:
pnnx BiRefNet-matting-epoch_100.onnx inputshape=[1,3,1024,1024] fp16=1
fp16 = 1
optlevel = 2
device = cpu
inputshape = [1,3,1024,1024]f32
inputshape2 =
customop =
moduleop =
get inputshape from traced inputs
inputshape = [1,3,1024,1024]f32
inputshape2 =
############# pass_level0 onnx
inline_containers ... 0.10ms
eliminate_noop ... 47.50ms
fold_constants ... 2025-10-31 15:51:56.084163 [W:onnxruntime:pnnx, cpuid_info.cc:91 LogEarlyWarning] Unknown CPU vendor. cpuinfo_vendor value: 0
ort CreateSession failed Node (/decoder/Split_33) Op (Split) [ShapeInferenceError] Cannot parse data from external tensors. Please load external data into raw data for tensor: /decoder/Constant_1066_output_0
Has anyone successfully converted it? Could you give me an example?
You can try the notebook for ONNX conversion in the tutorials folder.
I made an attempt, but failed.
# in the BiRefNet folder
# export.py
import torch
from utils import check_state_dict
from models.birefnet import BiRefNet
with open('deform_conv2d_onnx_exporter.py') as fp:
    file_lines = fp.read()
file_lines = file_lines.replace(
    "return sym_help._get_tensor_dim_size(tensor, dim)",
    '''
    tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)
    if tensor_dim_size == None and (dim == 2 or dim == 3):
        import typing
        from torch import _C
        x_type = typing.cast(_C.TensorType, tensor.type())
        x_strides = x_type.strides()
        tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]
    elif tensor_dim_size == None and (dim == 0):
        import typing
        from torch import _C
        x_type = typing.cast(_C.TensorType, tensor.type())
        x_strides = x_type.strides()
        tensor_dim_size = x_strides[3]
    return tensor_dim_size
    ''',
)
with open('deform_conv2d_onnx_exporter.py', mode="w") as fp:
    fp.write(file_lines)
from torchvision.ops.deform_conv import DeformConv2d
import deform_conv2d_onnx_exporter
deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()
birefnet = BiRefNet(bb_pretrained=False)
# state_dict = torch.load('BiRefNet-general-epoch_244.pth', map_location='cpu')
state_dict = torch.load('BiRefNet-matting-epoch_100.pth', map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)
torch.set_float32_matmul_precision(['high', 'highest'][0])
birefnet.to('cpu')
_ = birefnet.eval()
x = torch.rand(1, 3, 1024, 1024)
# You could try disabling trace checking if tracing raises an error
# mod = torch.jit.trace(birefnet, x, check_trace=False)
mode = torch.jit.trace(birefnet, x, strict=False)
mode.save("birefnet.pt")
print('save done')
# You could also try exporting to good old ONNX
torch.onnx.export(birefnet, x, 'birefnet.onnx')
Packages in my venv:
Package Version
------------------ ------------
anyio 4.11.0
certifi 2025.10.5
charset-normalizer 3.4.4
click 8.3.0
coloredlogs 15.0.1
einops 0.8.1
filelock 3.20.0
flatbuffers 25.9.23
fsspec 2025.9.0
h11 0.16.0
hf-xet 1.2.0
httpcore 1.0.9
httpx 0.28.1
huggingface-hub 1.0.1
humanfriendly 10.0
idna 3.11
Jinja2 3.1.6
kornia 0.8.1
kornia_rs 0.1.9
loguru 0.7.3
MarkupSafe 3.0.3
ml_dtypes 0.5.3
mpmath 1.3.0
ncnn 1.0.20250916
networkx 3.5
numpy 1.26.4
onnx 1.19.1
onnxruntime 1.23.1
opencv-python 4.11.0.86
packaging 25.0
pillow 12.0.0
pip 25.3
pnnx 20251031
portalocker 3.2.0
protobuf 6.33.0
PyYAML 6.0.3
requests 2.32.5
safetensors 0.6.2
scipy 1.16.3
setuptools 80.9.0
shellingham 1.5.4
sniffio 1.3.1
sympy 1.14.0
timm 1.0.21
torch 2.2.2
torchvision 0.17.2
tqdm 4.67.1
typer-slim 0.20.0
typing_extensions 4.15.0
urllib3 2.5.0
Export to ncnn with pnnx:
https://github.com/pnnx/pnnx
pnnx birefnet.pt "inputshape=[1,3,1024,1024]"
Test result:
python birefnet_ncnn.py
layer F.scaled_dot_product_attention not exists or registered
[1] 99823 segmentation fault python birefnet_ncnn.py
ncnn already supports F.scaled_dot_product_attention.
I don't know much about Torch and NCNN. Could someone help me out?
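Since the "layer ... not exists or registered" message is followed by a segmentation fault, it may help to inspect the generated param file before running birefnet_ncnn.py. The sketch below lists the layer types in an ncnn param file and flags the ones with a dot in their name. It relies on the text param layout (magic number 7767517 on the first line, layer and blob counts on the second, then one layer per line starting with its type). The dot heuristic is an assumption: ops that pnnx could not map to ncnn layers appear to keep names such as F.scaled_dot_product_attention, while native ncnn layer types contain no dot. The function names and the file name birefnet.ncnn.param are illustrative:

```python
from collections import Counter

NCNN_PARAM_MAGIC = "7767517"


def count_layer_types(param_path):
    """Count how often each layer type occurs in a text ncnn .param file."""
    with open(param_path, encoding="utf-8") as fp:
        lines = [ln.strip() for ln in fp if ln.strip()]
    if not lines or lines[0] != NCNN_PARAM_MAGIC:
        raise ValueError(f"{param_path} does not look like an ncnn param file")
    # lines[1] holds "<layer_count> <blob_count>"; layers start at lines[2].
    return Counter(ln.split()[0] for ln in lines[2:])


def suspicious_layer_types(param_path):
    """Return layer types that were probably left unconverted.

    Heuristic (an assumption, not an ncnn API): a dotted type name such as
    "F.xxx" or "torch.xxx" is treated as an op ncnn may not have registered.
    """
    return sorted(t for t in count_layer_types(param_path) if "." in t)


# Example call (file name is an assumption):
# print(suspicious_layer_types("birefnet.ncnn.param"))
```

If the list is not empty, each entry is a candidate for the kind of failure shown in the test result.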
Hey, what's ncnn? I didn't use it in my code.
Oh, if it's caused by SDPA, you can first set enable_SDPA in config.py to False. The non-SDPA path is 100% compatible, so the change is easy to make.
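To illustrate why turning SDPA off should not change the results, here is a small sketch of what F.scaled_dot_product_attention computes when no mask or dropout is used, written with basic tensor ops that exporters generally handle. It assumes PyTorch 2.x. The function name manual_attention is hypothetical, and BiRefNet's actual non-SDPA path may include extra terms (for example a relative position bias) that are not shown here:

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    """softmax(q @ k^T / sqrt(d)) @ v, with d the size of the last dimension."""
    scale = 1.0 / math.sqrt(q.size(-1))
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v


# Example comparison against the fused op (shapes are arbitrary):
# q, k, v = (torch.rand(1, 4, 16, 32) for _ in range(3))
# torch.allclose(manual_attention(q, k, v),
#                F.scaled_dot_product_attention(q, k, v), atol=1e-5)
```

After changing the config, the model needs to be traced and converted with pnnx again so that the new graph no longer contains the fused op.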