icevision

ONNX export from mmdet RetinaNet?

Open antosum opened this issue 3 years ago • 2 comments

Is there a convenient way to export models to ONNX?

I have a well-performing `mmdet.retinanet` model with a `resnet50_fpn_1x` backbone, but I am having a hard time exporting it to ONNX format.

I tried the `pytorch2onnx.py` script from MMDetection, but it was unsuccessful.

Calling `torch.onnx.export` directly doesn't work either:

torch.onnx.export(torch_model, 
                  (img, img_metas), 
                  "proto.onnx", 
                  export_params=True,
                  opset_version=11,  
                  do_constant_folding=True,  
                  input_names = ['input'],  
                  output_names = ['output'], 
                  dynamic_axes={'input' : {0 : 'batch_size'},
                                'output' : {0 : 'batch_size'}})

This raises the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_5522/2176814300.py in <module>
      1 # Export the model
----> 2 torch.onnx.export(torch_model,               # model being run
      3                   (img, img_metas),                         # model input (or a tuple for multiple inputs)
      4                   "proto.onnx",   # where to save the model (can be a file or file-like object)
      5                   export_params=True,        # store the trained parameter weights inside the model file

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    269 
    270     from torch.onnx import utils
--> 271     return utils.export(model, args, f, export_params, verbose, training,
    272                         input_names, output_names, aten, export_raw_ir,
    273                         operator_export_type, opset_version, _retain_param_name,

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     86         else:
     87             operator_export_type = OperatorExportTypes.ONNX
---> 88     _export(model, args, f, export_params, verbose, training, input_names, output_names,
     89             operator_export_type=operator_export_type, opset_version=opset_version,
     90             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
    689 
    690             graph, params_dict, torch_out = \
--> 691                 _model_to_graph(model, args, verbose, input_names,
    692                                 output_names, operator_export_type,
    693                                 example_outputs, _retain_param_name,

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, use_new_jit_passes, dynamic_axes)
    452         example_outputs = (example_outputs,)
    453 
--> 454     graph, params, torch_out, module = _create_jit_graph(model, args,
    455                                                          _retain_param_name,
    456                                                          use_new_jit_passes)

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes)
    415         return graph, params, torch_out, None
    416     else:
--> 417         graph, torch_out = _trace_and_get_graph_from_model(model, args)
    418         state_dict = _unique_state_dict(model)
    419         params = list(state_dict.values())

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args)
    375 
    376     trace_graph, torch_out, inputs_states = \
--> 377         torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
    378     warn_on_static_input_change(inputs_states)
    379 

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1137     if not isinstance(args, tuple):
   1138         args = (args,)
-> 1139     outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
   1140     return outs

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/anaconda3/envs/icevision/lib/python3.8/site-packages/torch/jit/_trace.py in forward(self, *args)
     91 
     92     def forward(self, *args: torch.Tensor):
---> 93         in_vars, in_desc = _flatten(args)
     94         # NOTE: use full state, because we need it for BatchNorm export
     95         # This differs from the compiler path, which doesn't support it at the moment.

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int

— antosum, May 9, 2022, 14:05

Not sure if it helps, but have you tried `torch.jit.trace()`?

— Modius22, May 23, 2022, 14:05

Can you try the `torchvision.retinanet` model instead of the `mmdet.retinanet` one?

— dnth, June 13, 2022, 10:06