SiamMask
Need help with conversion to ONNX
Hey @foolwood, I need help converting SiamMask to ONNX format. My Python script, which uses torch.onnx.export() for the conversion, looks like this:
from tools.test import *
#from siammask.models import Custom
from custom import Custom
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
parser.add_argument('--resume', default='', type=str, required=True, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--config', dest='config', default='config_davis.json', help='hyper-parameter of SiamMask in json format')
#parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
#parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
cfg = load_config(args)
siammask = Custom(anchors=cfg['anchors'])
siammask.load_state_dict(torch.load('SiamMask_DAVIS.pth')["state_dict"])
siammask.eval().to(device)
siammask.half()
template = torch.randn(1, 3, 127, 127).to(device).half()
search = torch.randn(1, 3, 255, 255).to(device).half()
label_cls = torch.randn(1, 1, 5).to(device).half()
input_dict = {'template': template, 'search': search}  #, 'label_cls': label_cls}
torch.onnx.export(siammask, input_dict, "SiamMask_DAVIS_half_test.onnx",
                  input_names=['template', 'search'],
                  opset_version=11,
                  do_constant_folding=True,
                  verbose=True,
                  output_names=['rpn_pred_cls', 'rpn_pred_loc', 'pred_mask'],
                  dynamic_axes={'search': {0: 'batch_size'},  # if you want batch size to be dynamic
                                'rpn_pred_cls': {0: 'batch_size'},
                                'rpn_pred_loc': {0: 'batch_size'},
                                'pred_mask': {0: 'batch_size'}})
The output looks like this:
[2024-02-14 15:40:21,552-rk0-features.py# 66] Current training 0 layers:
[2024-02-14 15:40:21,554-rk0-features.py# 66] Current training 1 layers:
====== Diagnostic Run torch.onnx.export version 1.14.0a0+44dac51c.nv23.01 ======
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Traceback (most recent call last):
File "../../tools/torch2onnx.py", line 78, in
To be specific, I need help figuring out the exact set of input and output parameters for torch.onnx.export() to perform the conversion.
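To make the question concrete, here is a minimal pure-Python sketch of how the name-related arguments could be kept consistent with each other. `build_export_kwargs` is a hypothetical helper, not part of SiamMask or PyTorch; it only assembles the `input_names`, `output_names`, and `dynamic_axes` keyword arguments that torch.onnx.export() accepts. The output names are the ones from my script, and whether they match what the model actually returns in inference mode is an assumption I have not verified.

```python
def build_export_kwargs(input_names, output_names, dynamic_inputs=None,
                        batch_axis=0, batch_label="batch_size"):
    """Assemble the naming keyword arguments for torch.onnx.export().

    Hypothetical helper: it does not call PyTorch, it only builds a dict.
    dynamic_inputs lists the inputs whose batch axis should be dynamic;
    by default every input is treated as dynamic.
    """
    input_names = list(input_names)
    output_names = list(output_names)
    if dynamic_inputs is None:
        dynamic_inputs = input_names
    else:
        dynamic_inputs = list(dynamic_inputs)

    # Every dynamic input must also be declared as an input name.
    unknown = set(dynamic_inputs) - set(input_names)
    if unknown:
        raise ValueError(f"dynamic inputs not in input_names: {sorted(unknown)}")

    # Mark the batch axis as dynamic on the chosen inputs and on all outputs.
    dynamic_axes = {name: {batch_axis: batch_label}
                    for name in dynamic_inputs + output_names}

    return {
        "input_names": input_names,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
    }


# Same choices as in the script above: only 'search' has a dynamic batch axis.
kwargs = build_export_kwargs(
    input_names=["template", "search"],
    output_names=["rpn_pred_cls", "rpn_pred_loc", "pred_mask"],
    dynamic_inputs=["search"],
)
```

The resulting dictionary could then be unpacked into the call, for example torch.onnx.export(siammask, (template, search), "SiamMask_DAVIS_half_test.onnx", opset_version=11, **kwargs). Passing the inputs as a tuple is the documented form of the args parameter; whether Custom.forward accepts template and search positionally in that order is an assumption.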