BiRefNet

How can I convert this model to the ncnn format? Directly converting the compiled ONNX model using pnnx throws an exception.

Open pangxiaobin opened this issue 1 month ago • 3 comments

When I use pnnx to convert this model, it raises an error:

pnnx  BiRefNet-matting-epoch_100.onnx inputshape=[1,3,1024,1024] fp16=1
fp16 = 1
optlevel = 2
device = cpu
inputshape = [1,3,1024,1024]f32
inputshape2 = 
customop = 
moduleop = 
get inputshape from traced inputs
inputshape = [1,3,1024,1024]f32
inputshape2 = 
############# pass_level0 onnx 
inline_containers ...                 0.10ms
eliminate_noop ...                   47.50ms
fold_constants ...                2025-10-31 15:51:56.084163 [W:onnxruntime:pnnx, cpuid_info.cc:91 LogEarlyWarning] Unknown CPU vendor. cpuinfo_vendor value: 0
ort CreateSession failed Node (/decoder/Split_33) Op (Split) [ShapeInferenceError] Cannot parse data from external tensors. Please load external data into raw data for tensor: /decoder/Constant_1066_output_0
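The error says onnxruntime cannot read weights that are stored as external data. A minimal sketch of embedding the external data back into a single .onnx file before running pnnx (assuming the external weight files sit next to the model file; this can fail if the embedded model exceeds the 2 GB protobuf limit):

import onnx

# onnx.load() picks up external data files located next to the model by default
model = onnx.load("BiRefNet-matting-epoch_100.onnx")
# re-save with all tensors embedded in one file
onnx.save_model(model, "BiRefNet-matting-epoch_100.embedded.onnx",
                save_as_external_data=False)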

Has anyone successfully converted it? Could you give me an example?

pangxiaobin avatar Oct 31 '25 07:10 pangxiaobin

You can try the notebook for ONNX conversion in the tutorials folder.

ZhengPeng7 avatar Oct 31 '25 19:10 ZhengPeng7

You can try the notebook for ONNX conversion in the tutorials folder.

I made an attempt, but failed.

# in the BiRefNet folder
# export.py
import torch
from utils import check_state_dict
from models.birefnet import BiRefNet


# Patch deform_conv2d_onnx_exporter so it can recover tensor dims from
# strides when shape inference returns None (needed to export DeformConv2d).
with open('deform_conv2d_onnx_exporter.py') as fp:
    file_lines = fp.read()

file_lines = file_lines.replace(
    "return sym_help._get_tensor_dim_size(tensor, dim)",
    '''
    tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)
    if tensor_dim_size == None and (dim == 2 or dim == 3):
        import typing
        from torch import _C

        x_type = typing.cast(_C.TensorType, tensor.type())
        x_strides = x_type.strides()

        tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]
    elif tensor_dim_size == None and (dim == 0):
        import typing
        from torch import _C

        x_type = typing.cast(_C.TensorType, tensor.type())
        x_strides = x_type.strides()
        tensor_dim_size = x_strides[3]

    return tensor_dim_size
    ''',
)

with open('deform_conv2d_onnx_exporter.py', mode="w") as fp:
    fp.write(file_lines)

# Register the custom ONNX symbolic for torchvision's DeformConv2d
from torchvision.ops.deform_conv import DeformConv2d
import deform_conv2d_onnx_exporter
deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()

# Build BiRefNet without downloading backbone weights, then load the local checkpoint
birefnet = BiRefNet(bb_pretrained=False)
# state_dict = torch.load('BiRefNet-general-epoch_244.pth', map_location='cpu')
state_dict = torch.load('BiRefNet-matting-epoch_100.pth', map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)

torch.set_float32_matmul_precision(['high', 'highest'][0])

birefnet.to('cpu')
_ = birefnet.eval()


# Dummy input matching the model's expected 1x3x1024x1024 input
x = torch.rand(1, 3, 1024, 1024)

# You could try disabling checking when tracing raises error
# mod = torch.jit.trace(net, x, check_trace=False)
mod = torch.jit.trace(birefnet, x, strict=False)

mod.save("birefnet.pt")
print('save done')


# You could also try exporting to the good-old onnx
torch.onnx.export(birefnet, x, 'birefnet.onnx')
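For comparison, a more explicit export call (the opset version and the tensor names below are my own illustrative choices, not values from the BiRefNet tutorials) would look like this:

# Explicit export variant; opset_version / input_names / output_names are assumptions
torch.onnx.export(
    birefnet,
    x,
    'birefnet_explicit.onnx',
    opset_version=17,
    do_constant_folding=True,
    input_names=['image'],
    output_names=['mask'],
)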

Packages in my venv:

Package            Version
------------------ ------------
anyio              4.11.0
certifi            2025.10.5
charset-normalizer 3.4.4
click              8.3.0
coloredlogs        15.0.1
einops             0.8.1
filelock           3.20.0
flatbuffers        25.9.23
fsspec             2025.9.0
h11                0.16.0
hf-xet             1.2.0
httpcore           1.0.9
httpx              0.28.1
huggingface-hub    1.0.1
humanfriendly      10.0
idna               3.11
Jinja2             3.1.6
kornia             0.8.1
kornia_rs          0.1.9
loguru             0.7.3
MarkupSafe         3.0.3
ml_dtypes          0.5.3
mpmath             1.3.0
ncnn               1.0.20250916
networkx           3.5
numpy              1.26.4
onnx               1.19.1
onnxruntime        1.23.1
opencv-python      4.11.0.86
packaging          25.0
pillow             12.0.0
pip                25.3
pnnx               20251031
portalocker        3.2.0
protobuf           6.33.0
PyYAML             6.0.3
requests           2.32.5
safetensors        0.6.2
scipy              1.16.3
setuptools         80.9.0
shellingham        1.5.4
sniffio            1.3.1
sympy              1.14.0
timm               1.0.21
torch              2.2.2
torchvision        0.17.2
tqdm               4.67.1
typer-slim         0.20.0
typing_extensions  4.15.0
urllib3            2.5.0

export to ncnn

https://github.com/pnnx/pnnx

pnnx birefnet.pt "inputshape=[1,3,1024,1024]"

test result

python birefnet_ncnn.py 
layer F.scaled_dot_product_attention not exists or registered
[1]    99823 segmentation fault  python birefnet_ncnn.py

It seems ncnn hasn't supported F.scaled_dot_product_attention.

I don't know much about Torch and NCNN. Could someone help me out?

pangxiaobin avatar Nov 06 '25 16:11 pangxiaobin

Hey, what's ncnn? I didn't use it in my code.

Oh, if it's caused by the SDPA, you can first set enable_SDPA in config.py to False. The non-SDPA path is 100% compatible, so you can easily switch it.
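For reference, disabling SDPA just means falling back to the explicit attention math. A minimal sketch of what F.scaled_dot_product_attention computes (the function and tensor names here are illustrative, not taken from BiRefNet's code):

import math
import torch

def manual_attention(q, k, v, attn_mask=None):
    # softmax(q @ k^T / sqrt(d)) @ v -- the core of scaled dot-product attention
    scale = 1.0 / math.sqrt(q.size(-1))
    attn = (q @ k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = attn.softmax(dim=-1)
    return attn @ v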

ZhengPeng7 avatar Nov 06 '25 16:11 ZhengPeng7