YOLOv5-Multibackbone-Compression
Use ConvTranspose2d instead of Upsample
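@Gumpest Hi, is it supported to replace nn.Upsample in yolov5s-pruning.yaml with nn.ConvTranspose2d and then prune? After making the replacement, I trained the model once following the provided documentation, and running pruneEagleEye.py then raises this error: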
File "/root/.pycharm_helpers/pydev/pydevd.py", line 1491, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/data/yolov5/pruneEagleEye.py", line 153, in <module>
rand_prune_and_eval(model, ignore_idx, opt)
File "/data/yolov5/pruneEagleEye.py", line 65, in rand_prune_and_eval
compact_model = Model(pruned_yaml, pruning=False).to(device)
File "/data/yolov5/models/yolo_prune.py", line 325, in __init__
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
File "/data/yolov5/models/yolo_prune.py", line 324, in <lambda>
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
File "/data/yolov5/models/yolo_prune.py", line 340, in forward
return self._forward_once(x, profile, visualize) # single-scale inference, train
File "/data/yolov5/models/yolo_prune.py", line 250, in _forward_once
x = m(x) # run
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/data/yolov5/models/common.py", line 805, in forward
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/data/yolov5/models/common.py", line 90, in forward
return self.act(self.bn(self.conv(x)))
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 439, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 56, 1, 1], expected input[1, 128, 32, 32] to have 56 channels, but got 128 channels instead
This is the structure used for training:
from n params module arguments
0 -1 1 3520 models.common.Conv [3, 32, 6, 2, 2, 1, True]
1 -1 1 18560 models.common.Conv [32, 64, 3, 2, None, 1, True]
2 -1 1 18816 models.common.C3_prune [64, 64, 64, 1, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
3 -1 1 73984 models.common.Conv [64, 128, 3, 2, None, 1, True]
4 -1 2 231424 models.common.C3_prune [128, 128, 128, 2, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
5 -1 1 295424 models.common.Conv [128, 256, 3, 2, None, 1, True]
6 -1 3 1875456 models.common.C3_prune [256, 256, 256, 3, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
7 -1 1 1180672 models.common.Conv [256, 512, 3, 2, None, 1, True]
8 -1 1 1182720 models.common.C3_prune [512, 512, 512, 1, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
9 -1 1 656896 models.common.SPPF_prune [512, 512, 5, 0.5]
10 -1 1 131584 models.common.Conv [512, 256, 1, 1, None, 1, True]
11 -1 1 262400 torch.nn.modules.conv.ConvTranspose2d [256, 256, 2, 2, 0]
12 [-1, 6] 1 0 models.common.Concat [1]
13 -1 1 361984 models.common.C3_prune [512, 256, 256, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
14 -1 1 33024 models.common.Conv [256, 128, 1, 1, None, 1, True]
15 -1 1 65664 torch.nn.modules.conv.ConvTranspose2d [128, 128, 2, 2, 0]
16 [-1, 4] 1 0 models.common.Concat [1]
17 -1 1 90880 models.common.C3_prune [256, 128, 128, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
18 -1 1 147712 models.common.Conv [128, 128, 3, 2, None, 1, True]
19 [-1, 14] 1 0 models.common.Concat [1]
20 -1 1 296448 models.common.C3_prune [256, 256, 256, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
21 -1 1 590336 models.common.Conv [256, 256, 3, 2, None, 1, True]
22 [-1, 10] 1 0 models.common.Concat [1]
23 -1 1 1182720 models.common.C3_prune [512, 512, 512, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
24 [17, 20, 23] 1 229245 models.yolo_prune.Detect [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
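As a sanity check on rows 11 and 15 above, the parameter counts and the 2x upsampling of the two ConvTranspose2d layers can be reproduced by hand. This is a minimal sketch, assuming the listed arguments are [in_channels, out_channels, kernel_size, stride, padding] and that the remaining options are left at PyTorch's defaults (bias enabled, groups=1, dilation=1, output_padding=0). The helper names are hypothetical and not part of the repository.

```python
def conv_transpose2d_params(c_in, c_out, k, bias=True):
    """Weight tensor (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def conv_transpose2d_out_size(size, k, stride, padding, dilation=1, output_padding=0):
    """Output height/width, using the formula from the PyTorch ConvTranspose2d docs."""
    return (size - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1


print(conv_transpose2d_params(256, 256, 2))    # 262400, matches row 11
print(conv_transpose2d_params(128, 128, 2))    # 65664, matches row 15
print(conv_transpose2d_out_size(16, 2, 2, 0))  # 32, i.e. the spatial size doubles
```

With kernel_size=2, stride=2 and padding=0, the layer doubles the spatial size like nn.Upsample with scale_factor=2, but it has learnable weights and explicit input/output channel counts.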
Below is the structure of the best sub-network found by the search:
from n params module arguments
0 -1 1 2640 models.common.Conv [3, 24, 6, 2, 2, 1, True]
1 -1 1 6976 models.common.Conv [24, 32, 3, 2, None, 1, True]
2 -1 1 9440 models.common.C3_prune [32, 40, 64, 1, True, 1, [0.5, 0.375], [0.5, 1.0, 1.0]]
3 -1 1 20272 models.common.Conv [40, 56, 3, 2, None, 1, True]
4 -1 2 212992 models.common.C3_prune [56, 128, 128, 2, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
5 -1 1 230800 models.common.Conv [128, 200, 3, 2, None, 1, True]
6 -1 3 1832448 models.common.C3_prune [200, 256, 256, 3, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
7 -1 1 571888 models.common.Conv [256, 248, 3, 2, None, 1, True]
8 -1 1 488832 models.common.C3_prune [248, 392, 512, 1, True, 1, [0.5, 0.484375], [0.25, 1.0, 1.0]]
9 -1 1 164638 models.common.SPPF_prune [392, 512, 5, 0.171875]
10 -1 1 94576 models.common.Conv [512, 184, 1, 1, None, 1, True]
11 -1 1 188672 torch.nn.modules.conv.ConvTranspose2d [184, 256, 2, 2, 0]
12 [-1, 6] 1 0 models.common.Concat [1]
13 -1 1 298048 models.common.C3_prune [512, 104, 256, 1, False, 1, [0.5, 0.34375], [1.0, 1.0, 1.0]]
14 -1 1 9328 models.common.Conv [104, 88, 1, 1, None, 1, True]
15 -1 1 45184 torch.nn.modules.conv.ConvTranspose2d [88, 128, 2, 2, 0]
16 [-1, 4] 1 0 models.common.Concat [1]
17 -1 1 72240 models.common.C3_prune [256, 88, 128, 1, False, 1, [0.5, 0.3125], [0.875, 1.0, 1.0]]
18 -1 1 50816 models.common.Conv [88, 64, 3, 2, None, 1, True]
19 [-1, 14] 1 0 models.common.Concat [1]
20 -1 1 111984 models.common.C3_prune [152, 96, 256, 1, False, 1, [0.5, 0.28125], [0.375, 1.0, 1.0]]
21 -1 1 214768 models.common.Conv [96, 248, 3, 2, None, 1, True]
22 [-1, 10] 1 0 models.common.Concat [1]
23 -1 1 610000 models.common.C3_prune [432, 304, 512, 1, False, 1, [0.5, 0.40625], [0.40625, 1.0, 1.0]]
24 [17, 20, 23] 1 125205 models.yolo_prune.Detect [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [88, 96, 304]]
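The channel numbers in the searched structure can be compared against the error message with a small script. This is only a rough sketch: the rows are copied by hand from the table above (the Detect row is omitted), and check_channels is a hypothetical helper. It also relies on an assumption that is not confirmed against yolo_prune.py: that a row with n > 1 is built as n copies of the module with identical arguments, chained one after another. The container.py frame in the traceback hints at this, but it is a guess.

```python
# Rows copied from the searched structure: (from, n, module, c_in, c_out).
# Concat rows carry no channel arguments, so c_in/c_out are None.
ROWS = [
    (-1, 1, "Conv", 3, 24),
    (-1, 1, "Conv", 24, 32),
    (-1, 1, "C3_prune", 32, 40),
    (-1, 1, "Conv", 40, 56),
    (-1, 2, "C3_prune", 56, 128),
    (-1, 1, "Conv", 128, 200),
    (-1, 3, "C3_prune", 200, 256),
    (-1, 1, "Conv", 256, 248),
    (-1, 1, "C3_prune", 248, 392),
    (-1, 1, "SPPF_prune", 392, 512),
    (-1, 1, "Conv", 512, 184),
    (-1, 1, "ConvTranspose2d", 184, 256),
    ([-1, 6], 1, "Concat", None, None),
    (-1, 1, "C3_prune", 512, 104),
    (-1, 1, "Conv", 104, 88),
    (-1, 1, "ConvTranspose2d", 88, 128),
    ([-1, 4], 1, "Concat", None, None),
    (-1, 1, "C3_prune", 256, 88),
    (-1, 1, "Conv", 88, 64),
    ([-1, 14], 1, "Concat", None, None),
    (-1, 1, "C3_prune", 152, 96),
    (-1, 1, "Conv", 96, 248),
    ([-1, 10], 1, "Concat", None, None),
    (-1, 1, "C3_prune", 432, 304),
]


def check_channels(rows, input_channels=3):
    """Return a list of channel mismatches found while walking the rows."""
    out = []       # output channels of every layer processed so far
    problems = []
    for i, (frm, n, name, c_in, c_out) in enumerate(rows):
        sources = frm if isinstance(frm, list) else [frm]
        incoming = 0
        for f in sources:
            idx = i + f if f < 0 else f  # negative indices are relative
            incoming += input_channels if idx < 0 else out[idx]
        if name == "Concat":
            c_out = incoming  # concatenation sums the channel counts
        elif c_in != incoming:
            problems.append(
                f"layer {i}: {name} expects {c_in} channels, receives {incoming}")
        elif n > 1 and c_in != c_out:
            # ASSUMPTION: n identical copies are chained, so copy 2 sees c_out.
            problems.append(
                f"layer {i}: {name} x{n} maps {c_in}->{c_out}; "
                f"a repeated copy would expect {c_in} but receive {c_out}")
        out.append(c_out)
    return problems


for line in check_channels(ROWS):
    print(line)
```

Under that assumption the script reports only layers 4 and 6, and the layer 4 numbers (expects 56, receives 128) are the same as in the RuntimeError. All adjacent layers agree with each other, including those around both ConvTranspose2d rows, so the replacement of nn.Upsample may not be the direct cause.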