yolov5prune icon indicating copy to clipboard operation
yolov5prune copied to clipboard

剪枝时剩余通道为1会报too many indices for tensor of dimension 3

Open HandsLing opened this issue 3 years ago • 2 comments

具体错误如下: #这是我把out_idx打印出来,前面有一个np.unsqueeze(),如果剩余通道为1那个用这个函数连这个1都没了 out_idx.shape: () layer.weight.size(): torch.Size([128, 512, 1, 1]) Traceback (most recent call last): File "prune.py", line 827, in opt=opt File "prune.py", line 506, in test_prune pruned_layer.weight.data = w[:,formerin, :, :].clone() IndexError: too many indices for tensor of dimension 3

HandsLing avatar Sep 08 '21 02:09 HandsLing

你好,我也出现这个问题,请问你解决了吗

nrikoh avatar Sep 13 '21 14:09 nrikoh

解决了,在剪枝的时候判断剪枝的mask中1的个数是否为1,若为1则逐渐减小thresh以保证mask中1的个数大于1,相当于剩余通道不为1了,以下为代码: prune_utils.py中 def obtain_bn_mask(bn_module, thre):

thre = thre.cuda()
mask = bn_module.weight.data.abs().ge(thre).float()
ones = (mask == 1.).sum().sum()
while ones<=1:
    thre -= 0.005
    mask = bn_module.weight.data.abs().ge(thre).float()
    ones = (mask == 1.).sum().sum() 
return mask

HandsLing avatar Sep 14 '21 01:09 HandsLing