
Need help with conversion to ONNX

Open surajs52 opened this issue 1 year ago • 0 comments

Hey @foolwood, I need help with conversion to ONNX format. My Python script, which uses torch.onnx.export() for the conversion, looks like this:

from tools.test import *
# from siammask.models import Custom
from custom import Custom

parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
parser.add_argument('--resume', default='', type=str, required=True, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--config', dest='config', default='config_davis.json',
                    help='hyper-parameter of SiamMask in json format')
# parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
# parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

cfg = load_config(args)
siammask = Custom(anchors=cfg['anchors'])
siammask.load_state_dict(torch.load('SiamMask_DAVIS.pth')["state_dict"])
siammask.eval().to(device)
siammask.half()

template = torch.randn(1, 3, 127, 127).to(device).half()
search = torch.randn(1, 3, 255, 255).to(device).half()
label_cls = torch.randn(1, 1, 5).to(device).half()
input_dict = {'template': template, 'search': search}  # , 'label_cls': label_cls}

torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
                  input_names=['template', 'search'],
                  opset_version=11,
                  do_constant_folding=True,
                  verbose=True,
                  output_names=['rpn_pred_cls', 'rpn_pred_loc', 'pred_mask'],
                  dynamic_axes={'search': {0: 'batch_size'},  # if you want batch size to be dynamic
                                'rpn_pred_cls': {0: 'batch_size'},
                                'rpn_pred_loc': {0: 'batch_size'},
                                'pred_mask': {0: 'batch_size'}})
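(From what I understand of the torch.onnx.export docs, a dict passed as args is interpreted as keyword arguments for forward(), not as a single positional argument. If that is right, keeping input_dict as one positional argument would need a trailing empty dict, as in the unverified sketch below; I also suspect forward() here is the training path, which may additionally expect label tensors.)

```python
# Sketch, not verified: the trailing empty dict makes input_dict the single
# positional `input` argument instead of being unpacked into keyword args.
torch.onnx.export(siammask, (input_dict, {}), "SiamMask_DAVIS_half_test.onnx",
                  input_names=['template', 'search'], opset_version=11)
```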

The output looks like this:

[2024-02-14 15:40:21,552-rk0-features.py# 66] Current training 0 layers:

[2024-02-14 15:40:21,554-rk0-features.py# 66] Current training 1 layers:

====== Diagnostic Run torch.onnx.export version 1.14.0a0+44dac51c.nv23.01 ======
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "../../tools/torch2onnx.py", line 78, in <module>
    torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1533, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 1260, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/x/archiconda3/envs/siammask38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1467, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'input'
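The TypeError at the bottom suggests forward() never received its single input dict, which would fit the dict-as-keyword-arguments behaviour mentioned above. One thing I am considering (not verified) is exporting through a small wrapper whose forward() takes plain tensors and calls the inference methods the way tools/test.py does; the template()/track_mask() calls below are my assumption about the right API, and the number of outputs may need adjusting:

```python
import torch.nn as nn

class SiamMaskExportWrapper(nn.Module):
    """Hypothetical tensor-only wrapper around the SiamMask Custom model."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, template, search):
        # Assumed API, mirroring tools/test.py: cache the template branch,
        # then run the mask-tracking branch on the search crop.
        self.model.template(template)
        rpn_pred_cls, rpn_pred_loc, pred_mask = self.model.track_mask(search)
        return rpn_pred_cls, rpn_pred_loc, pred_mask
```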

To be specific, I need help figuring out the exact set of inputs and outputs to pass to torch.onnx.export() so the conversion goes through.
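If a wrapper like the one sketched above is viable for this model, I would expect the export call to take plain tensors with explicit names, roughly like this (again unverified, reusing the template/search tensors and the names from my script):

```python
wrapper = SiamMaskExportWrapper(siammask).eval()

torch.onnx.export(
    wrapper,
    (template, search),                     # plain tensors, no dict needed
    "SiamMask_DAVIS_half_test.onnx",
    input_names=['template', 'search'],
    output_names=['rpn_pred_cls', 'rpn_pred_loc', 'pred_mask'],
    opset_version=11,
    do_constant_folding=True,
    dynamic_axes={'search': {0: 'batch_size'},
                  'rpn_pred_cls': {0: 'batch_size'},
                  'rpn_pred_loc': {0: 'batch_size'},
                  'pred_mask': {0: 'batch_size'}},
)
```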
