Torch-Pruning
About weight mismatch
Hi @VainF! Thank you for your work. I have run into a strange problem:
After pruning my own model, the weight shapes no longer match the inputs they receive.
(RuntimeError: Given groups=1, weight of size [31, 32, 3, 3], expected input[16, 31, 38, 38] to have 32 channels, but got 31 channels instead)
So I checked the layers of the pruned model, and the results are as follows:
(0): Conv2d(3, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): PReLU(num_parameters=1)
(2): Conv2d(32, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): PReLU(num_parameters=1)
(4): Conv2d(32, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): PReLU(num_parameters=1)
(6): Conv2d(32, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): PReLU(num_parameters=1)
(8): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): PReLU(num_parameters=1)
The code I used is as follows:
def pruning():
    Model_saved = './Weight/'
    model = torch.load(Model_saved + 'Model.pth').cpu()
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 38, 38).float().cpu())
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            mode = tp.prune_conv
        elif isinstance(m, nn.Linear):
            mode = tp.prune_linear
        elif isinstance(m, nn.BatchNorm2d):
            mode = tp.prune_batchnorm
        else:
            continue
        weight = m.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        L1_norm = np.sum(np.abs(weight))
        num_pruned = int(out_channels * 0.2)
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()
        pruning_plan = DG.get_pruning_plan(m, mode, idxs=prune_index)
        print(pruning_plan)
        pruning_plan.exec()
    return model
In fact, all the Inception blocks in my model hit this problem. I'm not sure what causes it. I would appreciate your advice; thank you for your help.
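A note on the scoring step in the code above: `np.sum(np.abs(weight))` is called without an `axis` argument, so it reduces the whole tensor to a single number instead of one L1 norm per output channel. That may be part of why each layer ended up with 31 output channels rather than losing `int(32 * 0.2) = 6`, though it would not by itself explain the stale input channels. A minimal sketch of a per-channel score, using a random array as a hypothetical stand-in for a layer's weight (the names `total_norm` and `l1_per_channel` are illustrative, not from the original code):

```python
import numpy as np

# Hypothetical stand-in for a Conv2d weight: (out_channels, in_channels, kH, kW).
rng = np.random.default_rng(0)
weight = rng.standard_normal((32, 32, 3, 3)).astype(np.float32)
out_channels = weight.shape[0]

# Without an axis, np.sum collapses everything into one 0-d value,
# so there is no per-channel score left to rank.
total_norm = np.sum(np.abs(weight))

# Flatten each output channel and sum its absolute values:
# this gives one L1 score per output channel.
l1_per_channel = np.abs(weight).reshape(out_channels, -1).sum(axis=1)

# Pick the 20% of channels with the smallest L1 norm.
num_pruned = int(out_channels * 0.2)
prune_index = np.argsort(l1_per_channel)[:num_pruned].tolist()
```

Reshaping to `(out_channels, -1)` before summing also works for 2-D `nn.Linear` weights and 1-D `nn.BatchNorm2d` weights, so the same line can serve all three branches of the loop.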
I faced a similar issue on a different architecture. Check the number of channels you're passing as input; I think your input should be [16, 32, 38, 38].
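For reference, the weight size `[31, 32, 3, 3]` in the error is ordered (out_channels, in_channels, kH, kW), so the layer still expects 32 input channels while the layer before it now produces 31. A small sketch of that check, applied to the `Conv2d` shapes from the printout in the question. It assumes the convolutions are chained one after another (PReLU does not change the channel count); a real Inception block with concatenation would need a different check, and `find_channel_mismatches` is a hypothetical helper, not part of Torch-Pruning:

```python
# (in_channels, out_channels) of each Conv2d, copied from the printed model.
convs = [(3, 31), (32, 31), (32, 31), (32, 31), (32, 3)]


def find_channel_mismatches(layers):
    """Return (position, channels_received, channels_expected) for each
    layer whose in_channels differs from the previous layer's out_channels.
    Assumes a plain sequential chain."""
    mismatches = []
    for i in range(1, len(layers)):
        prev_out = layers[i - 1][1]
        cur_in = layers[i][0]
        if cur_in != prev_out:
            mismatches.append((i, prev_out, cur_in))
    return mismatches


mismatches = find_channel_mismatches(convs)
for pos, received, expected in mismatches:
    print(f"conv #{pos}: receives {received} channels, weight expects {expected}")
```

Under that assumption, every convolution after the first is flagged: the output channels were pruned, but the matching input channels of the next layer were not.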
Hi @cyf666-coder, any progress on this issue?