Scaled-YOLOv4-TensorRT

Hi, running gen_wts.py for the yolov4-tiny-tensorrt example fails with an error

Open • HeuMindFusion opened this issue 3 years ago • 1 comment

```
Traceback (most recent call last):
  File "gen_wts.py", line 13, in <module>
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
  File "/home/sany/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Darknet:
	Missing key(s) in state_dict: "total_ops", "total_params", "module_list.total_ops", "module_list.total_params", "module_list.0.total_ops", "module_list.0.total_params", "module_list.1.total_ops", "module_list.1.total_params", "module_list.2.total_ops", "module_list.2.total_params", "module_list.3.total_ops", "module_list.3.total_params", "module_list.4.total_ops", "module_list.4.total_params", "module_list.5.total_ops", "module_list.5.total_params", "module_list.6.total_ops", "module_list.6.total_params", "module_list.7.total_ops", "module_list.7.total_params", "module_list.8.total_ops", "module_list.8.total_params", "module_list.10.total_ops", "module_list.10.total_params", "module_list.11.total_ops", "module_list.11.total_params", "module_list.12.total_ops", "module_list.12.total_params", "module_list.13.total_ops", "module_list.13.total_params", "module_list.14.total_ops", "module_list.14.total_params", "module_list.15.total_ops", "module_list.15.total_params", "module_list.16.total_ops", "module_list.16.total_params", "module_list.18.total_ops", "module_list.18.total_params", "module_list.19.total_ops", "module_list.19.total_params", "module_list.20.total_ops", "module_list.20.total_params", "module_list.21.total_ops", "module_list.21.total_params", "module_list.22.total_ops", "module_list.22.total_params", "module_list.23.total_ops", "module_list.23.total_params", "module_list.24.total_ops", "module_list.24.total_params", "module_list.26.total_ops", "module_list.26.total_params", "module_list.27.total_ops", "module_list.27.total_params", "module_list.28.total_ops", "module_list.28.total_params", "module_list.29.total_ops", "module_list.29.total_params", "module_list.30.total_ops", "module_list.30.total_params", "module_list.31.total_ops", "module_list.31.total_params", "module_list.32.total_ops", "module_list.32.total_params", "module_list.34.total_ops", "module_list.34.total_params", "module_list.35.total_ops", "module_list.35.total_params", "module_list.36.total_ops", "module_list.36.total_params", "module_list.37.total_ops", "module_list.37.total_params".
```
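Every missing key ends in `total_ops` or `total_params`. Those names match the counter buffers that the `thop` FLOPs profiler registers on each module, which suggests (not confirmed in this thread) that the `Darknet` instance in this environment carries profiling buffers the checkpoint does not. A minimal, stdlib-only sketch of one possible workaround, assuming those keys are pure profiling counters and not weights the TensorRT side needs: filter them out of a state dict before exporting. The helper names `is_profiling_key` and `drop_profiling_keys` are hypothetical, not part of the repository.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

# Assumption: these buffer names come from a FLOPs profiler (e.g. thop), not from real weights.
PROFILING_SUFFIXES = ("total_ops", "total_params")


def is_profiling_key(key: str) -> bool:
    """True if the last dotted component of a state_dict key is a profiling counter."""
    return key.rsplit(".", 1)[-1] in PROFILING_SUFFIXES


def drop_profiling_keys(state_dict: Mapping[str, V]) -> "OrderedDict[str, V]":
    """Return a copy of state_dict without profiling counters, preserving key order."""
    return OrderedDict(
        (k, v) for k, v in state_dict.items() if not is_profiling_key(k)
    )


if __name__ == "__main__":
    # Illustrative keys only; values stand in for tensors.
    sd = OrderedDict([
        ("total_ops", 0),
        ("module_list.0.Conv2d.weight", [1.0, 2.0]),
        ("module_list.0.total_params", 0),
        ("module_list.0.BatchNorm2d.running_mean", [0.0]),
    ])
    print(list(drop_profiling_keys(sd)))
    # ['module_list.0.Conv2d.weight', 'module_list.0.BatchNorm2d.running_mean']
```

In gen_wts.py this would mean iterating over `drop_profiling_keys(model.state_dict())` (for both the count on the first line and the tensor loop) and loading the checkpoint with `model.load_state_dict(..., strict=False)`. Note that `strict=False` also hides any other key mismatch, so it is worth checking that the reported missing keys are only profiling counters.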

gen_wts.py:

```python
import struct
import sys
from models import *
from utils.utils import *
from utils.torch_utils import select_device
model = Darknet('cfg/yolov4-tiny.cfg', (416, 416))
weights = sys.argv[1]
device = select_device('cpu')

dev = '0'
print(model)
if weights.endswith('.pt'):  # pytorch format
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
    print("------------------------------")
else:  # darknet format
    load_darknet_weights(model, weights)

f = open('yolov4-tiny.wts', 'w')
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
    vr = v.reshape(-1).cpu().numpy()
    f.write('{} {} '.format(k, len(vr)))
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex())
    f.write('\n')
```
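The export loop in gen_wts.py defines a simple text format: the first line is the number of tensors, and each following line holds the tensor name, its element count, and every element as a big-endian IEEE-754 float32 in hex. A minimal stdlib-only sketch that mirrors that loop and parses a line back; the helper names `encode_wts_line`, `decode_wts_line` and `encode_wts` are hypothetical, and the reader is an assumption about how a consumer could parse the file, not the repository's TensorRT loader.

```python
import struct
from typing import List, Sequence, Tuple


def encode_wts_line(name: str, values: Sequence[float]) -> str:
    """Mirror gen_wts.py: 'name count ' then ' ' + big-endian float32 hex per value."""
    parts = ["{} {} ".format(name, len(values))]
    for v in values:
        parts.append(" ")
        parts.append(struct.pack(">f", float(v)).hex())
    return "".join(parts) + "\n"


def decode_wts_line(line: str) -> Tuple[str, List[float]]:
    """Parse one tensor line into (name, values); split() absorbs the double space."""
    fields = line.split()
    name, count = fields[0], int(fields[1])
    hex_values = fields[2:]
    if len(hex_values) != count:
        raise ValueError(f"{name}: expected {count} values, got {len(hex_values)}")
    return name, [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]


def encode_wts(tensors: Sequence[Tuple[str, Sequence[float]]]) -> str:
    """Header line with the tensor count, followed by one line per tensor."""
    return "{}\n".format(len(tensors)) + "".join(
        encode_wts_line(n, v) for n, v in tensors
    )


if __name__ == "__main__":
    line = encode_wts_line("module_list.0.Conv2d.weight", [1.0, -2.0])
    print(repr(line))  # 'module_list.0.Conv2d.weight 2  3f800000 c0000000\n'
    print(decode_wts_line(line))  # ('module_list.0.Conv2d.weight', [1.0, -2.0])
```

Because every element is packed with `'>f'`, integer buffers in the state dict (for example BatchNorm's `num_batches_tracked`) are written as float32 as well, and the tensor count on the first line includes every key the loop writes, profiling counters included.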

HeuMindFusion • Jun 10 '21 09:06

Try the u3_preview branch of PyTorch_YOLOv4 (https://github.com/WongKinYiu/PyTorch_YOLOv4/tree/u3_preview) with yolov4-tiny.pt.

tjuskyzhang • Jun 11 '21 07:06