yolov5_prune
What is the reason for limiting the percent threshold?
In the `get_prune_threshold` function, a percent threshold is calculated as follows:
```python
import torch


def get_prune_threshold(model_list, percent):
    # gather_bn_weights() is a helper from the same repo (not shown here); it
    # presumably returns the |gamma| values of all BN layers as one 1-D tensor.
    bn_weights = gather_bn_weights(model_list)
    sorted_bn = torch.sort(bn_weights)[0]

    # Highest threshold that avoids pruning all channels of a layer: the minimum,
    # over all BN layers, of each layer's largest |gamma| is the upper bound.
    highest_thre = []
    for bnlayer in model_list.values():
        highest_thre.append(bnlayer.weight.data.abs().max().item())
    highest_thre = min(highest_thre)

    # Find the index of highest_thre in the sorted gammas and convert it to a percentage.
    threshold_index = (sorted_bn == highest_thre).nonzero().squeeze()
    if len(threshold_index.shape) > 0:
        # Several gammas tie with highest_thre; take the first occurrence.
        threshold_index = threshold_index[0]
    percent_threshold = threshold_index.item() / len(bn_weights)
    print('Suggested Gamma threshold should be less than {}'.format(highest_thre))
    print('The corresponding prune ratio is {}, but you can set higher'.format(percent_threshold))

    # Global threshold actually used for pruning, chosen from the requested ratio.
    thre_index = int(len(sorted_bn) * percent)
    thre_prune = sorted_bn[thre_index]
    print('Gamma value that less than {} are set to zero'.format(thre_prune))
    print("=" * 94)
    print(f"|\t{'layer name':<25}{'|':<10}{'origin channels':<20}{'|':<10}{'remaining channels':<20}|")
    return thre_prune
```
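The bound follows from the comment in the code: a global threshold above the smallest per-layer maximum |gamma| would remove every channel of at least one BN layer. The toy sketch below reproduces that arithmetic in plain Python; the `layers` dict, its values, and the `prune_threshold_bounds` helper are hypothetical illustrations, not part of the repo.

```python
# Hypothetical per-layer BN |gamma| values; made-up numbers, not from a real model.
layers = {
    "conv1.bn": [0.90, 0.40, 0.05, 0.70],
    "conv2.bn": [0.08, 0.12, 0.03, 0.10],  # weakest layer: its largest |gamma| is 0.12
    "conv3.bn": [0.60, 0.02, 0.55, 0.30],
}


def prune_threshold_bounds(layers):
    """Mirror the bound computed in get_prune_threshold, using plain lists."""
    # Global sorted list of |gamma| over all BN layers (like sorted_bn).
    sorted_bn = sorted(abs(g) for gammas in layers.values() for g in gammas)
    # Each layer's largest |gamma|; the smallest of these is the safe upper bound.
    highest_thre = min(max(abs(g) for g in gammas) for gammas in layers.values())
    # Position of that bound in the sorted list, expressed as a fraction.
    percent_threshold = sorted_bn.index(highest_thre) / len(sorted_bn)
    return highest_thre, percent_threshold


highest_thre, percent_threshold = prune_threshold_bounds(layers)
print(highest_thre, round(percent_threshold, 4))  # 0.12 0.4167
```

Here 0.12 sits at index 5 of the 12 sorted gammas, so the suggested prune ratio is 5/12.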
We run prune.py with the `--percent` parameter, for example `python prune.py --percent 0.5 --weights runs/train/coco_sparsity2/weights/last.pt --data data/coco.yaml --cfg models/yolov5s.yaml --imgsz 640`. If `--percent` is larger than the calculated `percent_threshold`, an error occurs.
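A plausible explanation for the error, assuming the repo keeps a channel only when its |gamma| is greater than or equal to the global threshold (consistent with the message that smaller values are set to zero): once `--percent` pushes the threshold past the weakest layer's largest gamma, that layer keeps zero channels, leaving nothing to feed the next layer. The exact exception depends on downstream code not shown here. The sketch below counts surviving channels per layer on toy data; the `layers` values and the `remaining_channels` helper are hypothetical.

```python
# Hypothetical per-layer BN |gamma| values; made-up numbers for illustration only.
layers = {
    "conv1.bn": [0.90, 0.40, 0.05, 0.70],
    "conv2.bn": [0.08, 0.12, 0.03, 0.10],
    "conv3.bn": [0.60, 0.02, 0.55, 0.30],
}


def remaining_channels(layers, percent):
    """Count channels kept per layer for a global prune ratio `percent`.

    Assumption: a channel survives when |gamma| >= the global threshold,
    which is picked as sorted_bn[int(len(sorted_bn) * percent)].
    """
    sorted_bn = sorted(abs(g) for gammas in layers.values() for g in gammas)
    thre_prune = sorted_bn[int(len(sorted_bn) * percent)]
    return {name: sum(abs(g) >= thre_prune for g in gammas)
            for name, gammas in layers.items()}


for percent in (0.4, 0.5):
    counts = remaining_channels(layers, percent)
    empty = [name for name, n in counts.items() if n == 0]
    # 0.4 keeps 2 channels in conv2.bn; 0.5 empties conv2.bn entirely.
    print(percent, counts, "empty layers:", empty)
```

With this toy data the weakest layer's largest gamma (0.12) sits at 5/12 ≈ 0.417 of the sorted list, so a ratio of 0.4 is safe while 0.5 removes every channel of `conv2.bn`.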
I am trying to use this GitHub code for training and pruning on a custom dataset. Could you explain why the percent threshold is limited? Thanks.