
MaskRCNN from ScriptModule to ONNX - Unknown Type BoxCoder

Open • Nuno-Mota opened this issue 2 years ago • 7 comments

🐛 Describe the bug

While attempting to export MaskRCNN to ONNX starting from a ScriptModule, an error occurs indicating that __torch__.torchvision.models.detection._utils.BoxCoder is an unknown type.

MWE:

import torch
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn

model = maskrcnn_resnet50_fpn()
model.eval()
script_model = torch.jit.script(model)
example_image = torch.rand((3, 800, 1000))
torch.onnx.export(
	script_model,
	[example_image],
	"test.onnx",
	example_outputs=script_model([example_image])[1], # index 0 is losses
	opset_version = 11
)

Error traceback:

Traceback (most recent call last):
  File "/home/nmota/test_onnx.py", line 8, in <module>
    torch.onnx.export(
  File "/usr/lib/python3.9/site-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/lib/python3.9/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/lib/python3.9/site-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/lib/python3.9/site-packages/torch/onnx/utils.py", line 458, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args,
  File "/usr/lib/python3.9/site-packages/torch/onnx/utils.py", line 402, in _create_jit_graph
    module, params = torch._C._jit_onnx_list_model_parameters(freezed_m)
RuntimeError: 
Unknown type __torch__.torchvision.models.detection._utils.BoxCoder (of Python compilation unit at: 0x55bbdb787f00) encountered in handling model params. This class type does not extend __getstate__ method.:


Unfortunately, I cannot test with a more recent version. Is this something that has been fixed recently?
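For readers unfamiliar with the __getstate__ wording in the error: in plain Python, `__getstate__`/`__setstate__` are the hooks a serializer such as `pickle` uses to decide which state of an object to save and how to restore it. The sketch below shows that protocol on a hypothetical `BoxCoderLike` class; its two attributes (`weights`, `bbox_xform_clip`) are assumptions loosely modeled on torchvision's BoxCoder. It is only an analogy: the check that raised the error above lives in PyTorch's own TorchScript/ONNX code and is not reproduced here.

```python
import math
import pickle
from typing import Any, Dict, Tuple


class BoxCoderLike:
    """Hypothetical stand-in for a helper class stored as a model attribute."""

    def __init__(self, weights: Tuple[float, float, float, float],
                 bbox_xform_clip: float = math.log(1000.0 / 16)) -> None:
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def __getstate__(self) -> Dict[str, Any]:
        # Tell the serializer exactly which fields make up this object's state.
        return {"weights": self.weights, "bbox_xform_clip": self.bbox_xform_clip}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Rebuild the object from the saved state (pickle does not call __init__).
        self.weights = tuple(state["weights"])
        self.bbox_xform_clip = float(state["bbox_xform_clip"])


coder = BoxCoderLike((10.0, 10.0, 5.0, 5.0))
restored = pickle.loads(pickle.dumps(coder))  # round-trip through both hooks
print(restored.weights, restored.bbox_xform_clip)
```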

Versions

PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 11.4
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 11.1.0
Clang version: 12.0.1
CMake version: version 3.21.2
Libc version: glibc-2.33

Python version: 3.9.6 (default, Jun 30 2021, 10:22:16)  [GCC 11.1.0] (64-bit runtime)
Python platform: Linux-5.13.13-arch1-1-x86_64-with-glibc2.33
Is CUDA available: True
CUDA runtime version: 11.4.100
GPU models and configuration: 
GPU 0: NVIDIA TITAN X (Pascal)

Nvidia driver version: 470.63.01
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.2.2
/usr/lib/libcudnn_adv_infer.so.8.2.2
/usr/lib/libcudnn_adv_train.so.8.2.2
/usr/lib/libcudnn_cnn_infer.so.8.2.2
/usr/lib/libcudnn_cnn_train.so.8.2.2
/usr/lib/libcudnn_ops_infer.so.8.2.2
/usr/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.9.0
[pip3] torchvision==0.10.0a0
[conda] Could not collect

cc @neginraoof

Nuno-Mota • Apr 08 '22 11:04

@Nuno-Mota I'm not too familiar with ONNX, but is there a reason you are JIT-scripting the model prior to exporting it?

The intended way is to do something like:

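import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn

# Eager (non-scripted) model; weights_backbone=None skips downloading backbone weights
model = maskrcnn_resnet50_fpn(weights_backbone=None)
model.eval()
example_image = torch.rand((3, 800, 1000))
# The args tuple holds the model's single positional input: a list of image tensors
torch.onnx.export(
    model,
    ([example_image],),
    "test.onnx",
    opset_version=11,
)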

This works fine in the latest version.

datumbox • Apr 08 '22 13:04

@datumbox, the idea is to try to preserve dynamic control flow, as mentioned in the docs.

Nuno-Mota • Apr 08 '22 13:04

@Nuno-Mota thanks for clarifying. As I said, I'm not too familiar with ONNX, and I'm trying to understand the status of the support from the existing tests. Upon investigating, I saw that we don't test against the JIT-scripted versions, which, according to the quoted doc, means that we actually trace the model.

@fmassa Do you have any context concerning this choice? Is it deliberate? As far as I understand, the detection models are not traceable due to their loops.
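The tracing-versus-scripting distinction discussed above can be illustrated without PyTorch. The sketch below is a hypothetical toy tracer, not how `torch.jit.trace` is implemented: it runs a function once on an example input and records only the operations that actually executed, so a loop whose trip count depends on the data is unrolled into a fixed number of steps. Replaying that recording on an input that needs a different number of iterations gives a different answer than the original function, which is the kind of information loss that scripting is meant to avoid.

```python
from typing import Callable, List


class Recorder:
    """Hypothetical toy tracer: logs each arithmetic step that actually runs."""

    def __init__(self) -> None:
        self.ops: List[Callable[[float], float]] = []

    def halve(self, x: float) -> float:
        self.ops.append(lambda v: v / 2)
        return x / 2


def shrink(x: float, rec: Recorder) -> float:
    # Data-dependent loop: the number of iterations depends on the input value.
    while x > 1.0:
        x = rec.halve(x)
    return x


def trace(fn: Callable[[float, Recorder], float], example: float) -> Callable[[float], float]:
    rec = Recorder()
    fn(example, rec)  # run once; the loop is "unrolled" into rec.ops
    ops = list(rec.ops)

    def replay(x: float) -> float:
        for op in ops:  # always the same fixed sequence of recorded steps
            x = op(x)
        return x

    return replay


traced = trace(shrink, 8.0)  # 8 -> 4 -> 2 -> 1: three halvings recorded
print(traced(8.0), shrink(8.0, Recorder()))    # same result on the example input
print(traced(64.0), shrink(64.0, Recorder()))  # traced version stops too early
```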

datumbox • Apr 08 '22 14:04

@Nuno-Mota I have the same issue, but with FasterRCNN. Have you found a solution?

medric49 • Oct 01 '22 21:10

Same issue when converting FasterRCNN. Any update?

RunnerZhong • May 15 '23 10:05

@medric49 @RunnerZhong I found a workaround if you are using the pretrained FasterRCNN available from PyTorch. It involves loading the scripted model, extracting its weights, loading them into an eager (non-scripted) instance of the pretrained model, and then converting that instance to ONNX.

Assuming your model is built from a template similar to this:

from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model(num_classes):
    frcnn_model = fasterrcnn_resnet50_fpn_v2(weights='COCO_V1')

    in_features = frcnn_model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    frcnn_model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return frcnn_model

Then you can create an ONNX-exportable model as follows:

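import torch

# Load the scripted model; torch.jit.load is the loader for TorchScript archives
scripted_model = torch.jit.load("jit_model.pt", map_location="cpu")
state_dict = scripted_model.state_dict()

num_classes = 2  # placeholder: set to your number of output classes
model = get_model(num_classes)  # eager model built from the template above
model.load_state_dict(state_dict)
model.eval()

example_image = torch.rand((3, 800, 1000))  # example input used for export
# Example output path and opset; add whichever other params you want here
torch.onnx.export(model, ([example_image],), "model.onnx", opset_version=11)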
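Before calling load_state_dict, it can help to confirm that the scripted model's parameter names and shapes line up with the eager model returned by get_model, since a different architecture variant changes the keys and a different num_classes changes the predictor's shapes. The helper below, `compare_state_dicts`, is a hypothetical sketch that works on plain name-to-shape dictionaries so it runs without PyTorch; with real models you would pass something like `{k: tuple(v.shape) for k, v in sd.items()}` for each state dict. The shapes in the usage part are made-up toy values. PyTorch's own `load_state_dict(..., strict=False)` also reports missing and unexpected keys.

```python
from typing import Dict, List, NamedTuple, Tuple

Shapes = Dict[str, Tuple[int, ...]]


class StateDictReport(NamedTuple):
    missing: List[str]     # expected by the target model, absent from the source
    unexpected: List[str]  # present in the source, unknown to the target model
    mismatched: List[str]  # present in both, but with different shapes


def compare_state_dicts(source: Shapes, target: Shapes) -> StateDictReport:
    """Hypothetical helper: compare parameter names and shapes before loading."""
    missing = sorted(target.keys() - source.keys())
    unexpected = sorted(source.keys() - target.keys())
    mismatched = sorted(
        k for k in source.keys() & target.keys() if source[k] != target[k]
    )
    return StateDictReport(missing, unexpected, mismatched)


# Toy shapes standing in for real state dicts (values are illustrative only).
scripted_shapes = {
    "roi_heads.box_predictor.cls_score.weight": (3, 1024),
    "roi_heads.box_predictor.cls_score.bias": (3,),
}
eager_shapes = {
    "roi_heads.box_predictor.cls_score.weight": (2, 1024),
    "roi_heads.box_predictor.cls_score.bias": (2,),
    "roi_heads.box_predictor.bbox_pred.weight": (8, 1024),
}
report = compare_state_dicts(scripted_shapes, eager_shapes)
print(report)
```

An empty report for all three fields suggests a strict load_state_dict should succeed.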

adamzenith • Nov 15 '23 14:11

Just came across this: I can reproduce the same issue (the same error message @Nuno-Mota reported) using torchvision 0.18.1. Are there any updates, or further attempts at a solution, for running torch.jit.script on these models?

stes • Jul 07 '24 21:07