ONNX export from mmdet retinanet?
Is there a convenient way to export models to ONNX?
I have a pretty satisfying mmdet.retinanet model with the resnet50_fpn_1x backbone, but I'm having a hard time exporting it to the ONNX format.
I've tried the pytorch2onnx.py script from mmdetection, but it was unsuccessful,
and torch.onnx.export doesn't work either:
torch.onnx.export(torch_model,
                  (img, img_metas),
                  "proto.onnx",
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
gets me:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_5522/2176814300.py in <module>
1 # Export the model
----> 2 torch.onnx.export(torch_model, # model being run
3 (img, img_metas), # model input (or a tuple for multiple inputs)
4 "proto.onnx", # where to save the model (can be a file or file-like object)
5 export_params=True, # store the trained parameter weights inside the model file
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
269
270 from torch.onnx import utils
--> 271 return utils.export(model, args, f, export_params, verbose, training,
272 input_names, output_names, aten, export_raw_ir,
273 operator_export_type, opset_version, _retain_param_name,
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
86 else:
87 operator_export_type = OperatorExportTypes.ONNX
---> 88 _export(model, args, f, export_params, verbose, training, input_names, output_names,
89 operator_export_type=operator_export_type, opset_version=opset_version,
90 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
689
690 graph, params_dict, torch_out = \
--> 691 _model_to_graph(model, args, verbose, input_names,
692 output_names, operator_export_type,
693 example_outputs, _retain_param_name,
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, use_new_jit_passes, dynamic_axes)
452 example_outputs = (example_outputs,)
453
--> 454 graph, params, torch_out, module = _create_jit_graph(model, args,
455 _retain_param_name,
456 use_new_jit_passes)
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes)
415 return graph, params, torch_out, None
416 else:
--> 417 graph, torch_out = _trace_and_get_graph_from_model(model, args)
418 state_dict = _unique_state_dict(model)
419 params = list(state_dict.values())
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args)
375
376 trace_graph, torch_out, inputs_states = \
--> 377 torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
378 warn_on_static_input_change(inputs_states)
379
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1137 if not isinstance(args, tuple):
1138 args = (args,)
-> 1139 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
1140 return outs
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/jit/_trace.py in forward(self, *args)
91
92 def forward(self, *args: torch.Tensor):
---> 93 in_vars, in_desc = _flatten(args)
94 # NOTE: use full state, because we need it for BatchNorm export
95 # This differs from the compiler path, which doesn't support it at the moment.
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int
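For what it's worth, the final RuntimeError points at the root cause: the JIT tracer behind torch.onnx.export can only flatten tensors (plus tuples/lists/dicts of them), and mmdet's img_metas is a list of dicts holding plain ints, strings, and tuples, so tracing aborts before export even starts. One common workaround is to bake img_metas into a wrapper module so the traced forward only ever receives a tensor. A minimal sketch, assuming the usual mmdet detector signature forward(img, img_metas, return_loss=...) (the wrapper name is illustrative, and the exact call may differ across mmdet versions):

import torch

class OnnxWrapper(torch.nn.Module):
    def __init__(self, model, img_metas):
        super().__init__()
        self.model = model
        self.img_metas = img_metas  # captured as a constant, not traced

    def forward(self, img):
        # return_loss=False switches mmdet detectors to inference mode,
        # which expects batched inputs wrapped in lists
        return self.model(img=[img], img_metas=[self.img_metas],
                          return_loss=False)

wrapped = OnnxWrapper(torch_model, img_metas).eval()
torch.onnx.export(wrapped, img, "proto.onnx",
                  export_params=True, opset_version=11,
                  input_names=['input'], output_names=['output'])

Note this only clears the _flatten error; mmdet's post-processing (NMS and friends) may still hit ops the ONNX exporter can't handle, which is what mmdetection's pytorch2onnx.py script exists to work around.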
Not sure if it helps you, but have you tried torch.jit.trace()?
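For reference, torch.jit.trace() runs into the same _flatten restriction on raw mmdet inputs, but it does go through with a tensor-only wrapper like the hypothetical one sketched above:

# strict=False tolerates models that return lists/dicts instead of tensors
traced = torch.jit.trace(wrapped, img, strict=False)
traced.save("proto_traced.pt")  # TorchScript artifact, handy as a sanity check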
Can you try with torchvision.retinanet instead of the mmdet.retinanet model?
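That may well be the path of least resistance: torchvision's detection models are covered by torchvision's own ONNX export tests. A minimal sketch with plain torchvision (icevision's torchvision.retinanet wraps the same model); it uses random weights, so load your trained state_dict in practice:

import torch
import torchvision

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False)
model.eval()

# torchvision detection models take a list of 3D (C, H, W) tensors
x = [torch.rand(3, 800, 800)]
torch.onnx.export(model, x, "retinanet_torchvision.onnx", opset_version=11)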