yolov5prune
yolov5prune copied to clipboard
剪枝时剩余通道为1会报too many indices for tensor of dimension 3
具体错误如下:
#这是我把out_idx打印出来,前面有一个np.unsqueeze(),如果剩余通道为1那个用这个函数连这个1都没了
out_idx.shape: ()
layer.weight.size(): torch.Size([128, 512, 1, 1])
Traceback (most recent call last):
File "prune.py", line 827, in
你好,我也出现这个问题,请问你解决了吗
解决了,在剪枝的时候判断剪枝的mask中1的个数是否为1,若为1则逐渐减小thresh以保证mask中1的个数大于1,相当于剩余通道不为1了,以下为代码: prune_utils.py中 def obtain_bn_mask(bn_module, thre):
thre = thre.cuda()
mask = bn_module.weight.data.abs().ge(thre).float()
ones = (mask == 1.).sum().sum()
while ones<=1:
thre -= 0.005
mask = bn_module.weight.data.abs().ge(thre).float()
ones = (mask == 1.).sum().sum()
return mask