CRNN_Chinese_Characters_Rec
CRNN_Chinese_Characters_Rec copied to clipboard
训练出来的模型怎么转ONNX?GIT主?
@Sierkinhane
同问,使用脚本转换时:
python3 pytorch2onnx.py
Traceback (most recent call last):
File "pytorch2onnx.py", line 32, in
训练时候用的什么模型,你就加载什么模型
我用crnn_pytorch CRNN_Chinese_Characters_Rec训练出来的模型 pth,但是转ONNX一直不成功,
@.***
发件人: xiao12mm 发送时间: 2021-06-17 10:52 收件人: Sierkinhane/CRNN_Chinese_Characters_Rec 抄送: lwx; Comment 主题: Re: [Sierkinhane/CRNN_Chinese_Characters_Rec] 训练出来的模型怎么转ONNX?GIT主? (#280) 训练时候用的什么模型,你就加载什么模型 — You are receiving this because you commented. Reply to this email directly, view it on GitHub, or unsubscribe.
Error(s) in loading state_dict for CRNN: 加载模型参数的时候就报错了,训练的时候用的什么网络,加载的时候要加载相同的网络,这样参数才对的上
https://github.com/YIYANGCAI/CRNN-Pytorch2TensorRT-via-ONNX,我用的这个pytorch2onnx.py,模型我对比过了,一样的,但是加载参数的时候还是有很大差别, 这个问题是可以解决的,Unexpected key(s) in state_dict: "state_dict", "epoch", "best_acc". 但是用checkpoint = torch.load(modelpath) #modelpath是你要加载训练好的模型文件地址 model.load_state_dict(checkpoint['state_dict']) 还是出现问题的呢 model.load_state_dict(checkpoint['state_dict']) File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict self.class.name, "\n\t".join(error_msgs))) RuntimeError: Error(s) in loading state_dict for CRNN: size mismatch for rnn.0.rnn.weight_ih_l0: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([1120, 512]). size mismatch for rnn.0.rnn.weight_hh_l0: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([1120, 280]). size mismatch for rnn.0.rnn.bias_ih_l0: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1120]). size mismatch for rnn.0.rnn.bias_hh_l0: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1120]). size mismatch for rnn.0.rnn.weight_ih_l0_reverse: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([1120, 512]). size mismatch for rnn.0.rnn.weight_hh_l0_reverse: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([1120, 280]).