HitNet
HitNet copied to clipboard
How do I convert a .pth model to ONNX?
My Python code for converting the model is as follows, but it does not produce the ONNX model:
import torch
import torch.onnx
from lib.pvt import Hitnet
def pth_to_onnx(input, checkpoint, onnx_path, input_names=['input'], output_names=['output1', 'output2'], device='cuda'):
    """Export a Hitnet ``.pth`` checkpoint to an ONNX file.

    Args:
        input: Example tensor traced through the model during export. The
            name shadows the builtin but is kept so keyword callers still work.
        checkpoint: Path to the state-dict file read with ``torch.load``.
        onnx_path: Destination file; must end with ``.onnx``.
        input_names: Names assigned to the exported graph inputs.
        output_names: Names assigned to the exported graph outputs.
        device: Device on which the model is built and traced. When a CUDA
            device is requested but CUDA is unavailable, the export falls
            back to the CPU instead of crashing.

    Returns:
        ``0`` when ``onnx_path`` is rejected, ``None`` after a completed export.
    """
    if not onnx_path.endswith('.onnx'):
        # Implicit concatenation keeps the message on one line with a proper
        # space; a backslash continuation inside the literal did not.
        print("Warning! The onnx model name is not correct, "
              "please give a name that ends with '.onnx'!")
        return 0

    cuda_available = torch.cuda.is_available()
    print(cuda_available)

    # Honour the caller's device instead of unconditionally calling .cuda().
    target = torch.device(device)
    if target.type == 'cuda' and not cuda_available:
        print('Warning! CUDA is not available, exporting on the CPU instead.')
        target = torch.device('cpu')

    model = Hitnet().to(target)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(checkpoint, map_location=target)
    model.load_state_dict(state_dict)
    model.eval()

    torch.onnx.export(
        model,
        input.to(target),  # model and example input must share a device
        onnx_path,
        verbose=False,
        # Copies keep the shared default lists from being mutated downstream.
        input_names=list(input_names),
        output_names=list(output_names),
    )
    print("Exporting .pth model to onnx model has been successful!")
if __name__ == '__main__':
    # Example run: export the best-epoch checkpoint next to this script.
    checkpoint = './Net_epoch_best.pth'
    onnx_path = './Net_epoch_best.onnx'
    # Pick the GPU only when one is present so the script also runs on
    # CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Random 1x3x480x640 tensor used only to trace the model
    # (presumably batch, channels, height, width - confirm against Hitnet).
    dummy_input = torch.randn(1, 3, 480, 640, device=device)
    pth_to_onnx(dummy_input, checkpoint, onnx_path, device=device)
We have not implemented the ONNX conversion; we only provide an ONNX template.