faster-rcnn.pytorch
faster-rcnn.pytorch copied to clipboard
Try to convert pth to onnx. But error
I wanna convert pth to onnx format. This is my code:
import torch from model.faster_rcnn.vgg16 import vgg16 from model.faster_rcnn.resnet import resnet import numpy as np from torch.autograd import Variable
def load_model(model, pretrained_path):
print('Loading pretrained model from {}'.format(pretrained_path))
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
model.load_state_dict(pretrained_dict, strict=False)
return model
output_onnx = './output.onnx' raw_weights = './faster_rcnn_1_10_2504.pth' pascal_classes = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])
load weight
net = resnet(pascal_classes, 101, pretrained=False, class_agnostic=False)
net.create_architecture()
checkpoint = torch.load(raw_weights)
for k in checkpoint.keys():
print(k)
net.load_state_dict(checkpoint['model'])
initilize the tensor holder here.
im_data = torch.FloatTensor(1) im_info = torch.FloatTensor(1) num_boxes = torch.LongTensor(1) gt_boxes = torch.FloatTensor(1)
ship to cuda
im_data = im_data.cuda() im_info = im_info.cuda() num_boxes = num_boxes.cuda() gt_boxes = gt_boxes.cuda()
make variable
im_data = Variable(im_data, volatile=True) im_info = Variable(im_info, volatile=True) num_boxes = Variable(num_boxes, volatile=True) gt_boxes = Variable(gt_boxes, volatile=True)
net.eval() print('Finished loading model!') device = torch.device("cuda") net = net.to(device)
input_names = ["input0"] output_names = ["output0"] inputs = torch.randn(1, 3, 300, 300).to(device)
output model
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
but when I run it, I got this error. How can I fix it?Thanks!
Traceback (most recent call last):
File "pth2onnx.py", line 67, in