HitNet

How to convert a .pth model to ONNX?

Open · meifannao opened this issue 1 year ago · 1 comment

My Python code for converting the model is below, but it does not produce an ONNX model:

```python
import torch
import torch.onnx
from lib.pvt import Hitnet


def pth_to_onnx(dummy_input, checkpoint, onnx_path, input_names=['input'],
                output_names=['output1', 'output2'], device='cuda'):
    # Refuse to write a file whose name does not end in .onnx
    if not onnx_path.endswith('.onnx'):
        print("Warning! The onnx model name is not correct, "
              "please give a name that ends with '.onnx'!")
        return False
    print('CUDA available:', torch.cuda.is_available())

    # Build the network on the requested device and load the trained weights
    model = Hitnet().to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()  # disable dropout and use stored batch-norm statistics

    # Trace the model with the dummy input and write the ONNX graph
    with torch.no_grad():
        torch.onnx.export(model, dummy_input.to(device), onnx_path, verbose=False,
                          input_names=input_names, output_names=output_names)
    print("Exporting .pth model to onnx model has been successful!")
    return True


if __name__ == '__main__':
    checkpoint = './Net_epoch_best.pth'
    onnx_path = './Net_epoch_best.onnx'
    dummy_input = torch.randn(1, 3, 480, 640)  # N, C, H, W
    pth_to_onnx(dummy_input, checkpoint, onnx_path)
```

meifannao, Aug 04 '23 09:08
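If the script stops before printing the success message, one thing worth ruling out (an assumption, since the issue shows no traceback) is a key mismatch in `load_state_dict`: checkpoints saved from an `nn.DataParallel` or `DistributedDataParallel` wrapper store every parameter under a `module.` prefix, which a bare model rejects. Whether `Net_epoch_best.pth` was saved that way is not stated here. The helper `strip_prefix` below is hypothetical and not part of HitNet; it is a plain-Python sketch that works on any mapping, so a loaded state dict can be passed straight in.

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it.

    Hypothetical helper: keys without the prefix are copied unchanged and the
    original key order is preserved.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


if __name__ == "__main__":
    # Plain ints stand in for tensors; the helper never inspects the values.
    saved = {"module.conv.weight": 1, "module.conv.bias": 2, "head.weight": 3}
    print(list(strip_prefix(saved)))
    # ['conv.weight', 'conv.bias', 'head.weight']
```

With PyTorch it would be applied as `model.load_state_dict(strip_prefix(torch.load(checkpoint, map_location=device)))`. Because unprefixed keys are left alone, the call is harmless on a checkpoint that never had the prefix.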

We don't implement the ONNX conversion; we only provide an ONNX template.

HUuxiaobin, Dec 11 '23 02:12