Person_reID_baseline_pytorch icon indicating copy to clipboard operation
Person_reID_baseline_pytorch copied to clipboard

能不能提供一个 onnx_export.py方便将模型转为onnx,以便再转为NCNN等框架可推理的模型

Open lmq5294249 opened this issue 1 year ago • 2 comments

我尝试将模型转为onnx出现错误,无法解决

lmq5294249 avatar Apr 18 '23 02:04 lmq5294249

你好 有具体错误可以贴一下么?感谢!

layumi avatar Apr 18 '23 08:04 layumi

train.py文件save_network方法改成这样

def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    save_path = os.path.join('./model',name,save_filename)
    # torch.save(network.cpu().state_dict(), save_path)
    # 上面注释的部分改成下面的
    torch.save(network, save_path)
    if torch.cuda.is_available():
        network.cuda(gpu_ids[0])

然后新建一个py文件,内容如下(其中输入模型和输出模型路径改成自己的):

import torch
import torch.nn
import onnx
 
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
 # 路径改成训练输出模型的位置
model = torch.load('/project/train/src_repo/model/ft_ResNet50/net_last.pth', map_location=device)
model.eval()
 
input_names = ['input']
output_names = ['output']
 
x = torch.randn(1, 3, 224, 224, device=device)
 # 路径改为转换onnx模型的位置
torch.onnx.export(model, x, '/project/train/src_repo/model/ft_ResNet50/net_last.onnx', input_names=input_names, output_names=output_names, verbose='True')

CaptainJi avatar May 04 '23 02:05 CaptainJi