Torch-Pruning
YOLOv8 pruning with the Taylor criterion
Hi, I'm experimenting with different criteria for YOLOv8 pruning, currently the Taylor criterion. However, when I set `global_pruning=True`, the pruner doesn't prune anything; it only works with local pruning. Can somebody help me with this? Here is a summary of my code:
...
def calculate_grad(model, **args):
    """Accumulate training-loss gradients into ``model``'s parameters.

    Runs one forward/backward pass over every batch of the training split.
    Gradient-based importance criteria (Taylor importance here) can then read
    ``param.grad`` when ``pruner.step()`` is called afterwards.

    Args:
        model: Model that receives the gradients. It is moved to the configured
            device for the passes and moved back to the CPU afterwards.
        **args: Must contain the key ``'args'``. Its value holds the overrides
            merged into ``DEFAULT_CFG`` via ``get_cfg``; ``data``, ``device`` and
            ``batch`` are read from the result.

    Notes:
        * Existing gradients are cleared first, so the result reflects only
          this pass.
        * The passes run in full precision and do not use ``GradScaler``.
          - The scaler's scale only changes through ``scaler.update()``, which
            is never called here, and the gradients would never be unscaled.
          - An fp16 overflow in any single batch would therefore leave inf/NaN
            in the accumulated gradients for good.
          - NaN importances can make a global threshold comparison select no
            channels at all, which is presumably why global pruning removed
            nothing. TODO confirm against the Torch-Pruning version in use.
          - Full precision needs more memory than autocast; lower ``batch`` if
            needed.
        * Batches whose loss is not finite are skipped rather than allowed to
          poison the accumulated gradients.
        * ``model.args`` is restored, and the model is moved back to the CPU,
          even if an exception is raised.
    """
    cfg = get_cfg(DEFAULT_CFG, args['args'])
    device = select_device(device=cfg.device, batch=0)

    tmp_args = model.args  # restored in ``finally``
    model.args = cfg
    model.to(device)
    try:
        # Built once, not once per batch. It is built only after ``model.args``
        # and the device are set, because the loss may read both at construction.
        compute_loss = Loss(de_parallel(model))

        def preprocess_batch(batch):
            """Move the batch images to ``device`` as floats divided by 255."""
            batch['img'] = batch['img'].to(device, non_blocking=True).float() / 255
            return batch

        data = check_det_dataset(cfg.data)
        gs = max(int(de_parallel(model).stride.max()), 32)  # largest model stride, at least 32
        train_loader = build_dataloader(cfg, cfg.batch, img_path=data['train'], stride=gs, rank=RANK,
                                        mode='train', rect=False, data_info=data)[0]

        model.zero_grad(set_to_none=True)  # start from clean gradients
        skipped = 0
        print('--------Calculate grad start--------')
        pbar = tqdm(train_loader, total=len(train_loader), bar_format=TQDM_BAR_FORMAT)
        for batch in pbar:
            batch = preprocess_batch(batch)
            preds = model(batch['img'])
            loss, _ = compute_loss(preds, batch)
            if not torch.isfinite(loss).all():
                # A single inf/NaN backward would corrupt every accumulated grad.
                skipped += 1
                continue
            loss.backward()
        if skipped:
            print(f'Skipped {skipped} batch(es) with a non-finite loss')
        print('--------Calculate grad end---------')
    finally:
        model.args = tmp_args
        model.to('cpu')
...
# The Taylor criterion relies on parameter gradients. calculate_grad() must fill
# them in after the pruner is built and before pruner.step() ranks the channels.
taylor_imp = tp.importance.GroupTaylorImportance()
pruner = tp.pruner.MetaPruner(
    model.model,
    example_inputs,
    importance=taylor_imp,
    global_pruning=True,  # rank channels jointly across layers instead of per layer
    iterative_steps=1,
    ch_sparsity=ch_sparsity,
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)
grad_cfg = deepcopy(pruning_cfg)
calculate_grad(model.model, args=grad_cfg)
pruner.step()
Full code: https://github.com/minhhotboy9x/ultralytics_YOLOv8_custom/blob/main/benchmarks/prunability/yolov8_pruning_taylor.py