Torch-Pruning
Torch-Pruning copied to clipboard
Initializing the pruner once in yolov8_pruning.py breaks the iterative pruning process
I see that yolov8_pruning.py (the script that prunes YOLOv8) always initializes a new pruner and deletes it again on every iteration. When I modified this file to initialize the pruner once, outside the iterative loop, the first pruning iteration ran normally, but after that the model was not pruned any further. Can somebody explain this, please?
# Modules the pruner must leave untouched: every Detect instance in the network.
ignored_layers = [
    module for module in model.model.modules() if isinstance(module, Detect)
]
# Nothing is registered as an unwrapped parameter in this snippet.
unwrapped_parameters = []

# L2-norm magnitude criterion used to rank what gets removed.
l2_importance = tp.importance.MagnitudeImportance(p=2)

# The pruner is constructed a single time here, before the iterative loop,
# from the `model.model` object that exists at this point.
pruner = tp.pruner.MagnitudePruner(
    model.model,
    example_inputs,
    importance=l2_importance,
    iterative_steps=args.iterative_steps,
    ch_sparsity=args.target_prune_rate,
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)
# One pass per pruning step: prune -> validate -> fine-tune -> validate -> record.
for i in range(args.iterative_steps):
    # Put the network in training mode and make every parameter trainable.
    model.model.train()
    for _, weight in model.model.named_parameters():
        weight.requires_grad = True

    # `pruner` is the single instance created before this loop, from the
    # `model.model` object that existed at that time.
    # NOTE(review): if `model.train_v2` below rebinds `model.model` to a
    # different object (not visible here -- confirm), later `pruner.step()`
    # calls would keep acting on the earlier object instead of the live model.
    pruner.step()  # remove some weights with lowest importance

    example_inputs = example_inputs.to(model.device)

    # --- validation right after pruning, before any fine-tuning ---
    pre_val_name = os.path.join(prefix_folder, f"step_{i}_pre_val")
    pruning_cfg['name'] = pre_val_name
    pruning_cfg['batch'] = 1
    # Validate on a copy so the validator does not touch the model being pruned.
    validation_model.model = deepcopy(model.model)
    metric = validation_model.val(**pruning_cfg)
    pruned_map = metric.box.map

    # Cost of the pruned network and speed-up relative to the first MACs entry.
    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(
        model.model, example_inputs
    )
    current_speed_up = float(macs_list[0]) / pruned_macs
    report = (
        f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
        f"mAP={pruned_map}, speed up={current_speed_up}"
    )
    print(report)

    # --- fine-tuning ---
    for _, weight in model.model.named_parameters():
        weight.requires_grad = True
    finetune_name = os.path.join(prefix_folder, f"step_{i}_finetune")
    pruning_cfg['name'] = finetune_name
    pruning_cfg['batch'] = batch_size  # restore batch size
    model.train_v2(pruning=True, **pruning_cfg)

    # --- validation after fine-tuning, using the trainer's best checkpoint ---
    post_val_name = os.path.join(prefix_folder, f"step_{i}_post_val")
    pruning_cfg['name'] = post_val_name
    pruning_cfg['batch'] = 1
    validation_model = YOLO(model.trainer.best)
    metric = validation_model.val(**pruning_cfg)
    current_map = metric.box.map
    print(f"After fine tuning mAP={current_map}")

    # Record this step's numbers; parameters are stored as a percentage of
    # the baseline parameter count.
    macs_list.append(pruned_macs)
    nparams_list.append(pruned_nparams / base_nparams * 100)
    pruned_map_list.append(pruned_map)
    map_list.append(current_map)