
About weight mismatch

cyf666-coder opened this issue 4 years ago · 2 comments

Hi @VainF! Thank you for your work, but I have run into a strange problem: after pruning my own model, the layer weights no longer match. The forward pass fails with:

RuntimeError: Given groups=1, weight of size [31, 32, 3, 3], expected input[16, 31, 38, 38] to have 32 channels, but got 31 channels instead

So I inspected the pruned model, and the layers look like this:

(0): Conv2d(3, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): PReLU(num_parameters=1)
(2): Conv2d(32, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): PReLU(num_parameters=1)
(4): Conv2d(32, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): PReLU(num_parameters=1)
(6): Conv2d(32, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): PReLU(num_parameters=1)
(8): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): PReLU(num_parameters=1)

Each conv's output channels were pruned from 32 to 31, but the next conv's input channels were left at 32, which is exactly where the runtime mismatch comes from.
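For reference, a minimal sketch (my own toy stack, not from the original model; names and shapes are assumed) that checks whether a pruning plan propagates an out-channel change through PReLU into the next conv's in_channels, using the same DependencyGraph API as the snippet below:

import torch
import torch.nn as nn
import torch_pruning as tp

# Hypothetical stack mirroring the Conv2d/PReLU printout above.
toy = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.PReLU(num_parameters=1),
    nn.Conv2d(32, 32, 3, padding=1),
)

DG = tp.DependencyGraph()
DG.build_dependency(toy, example_inputs=torch.randn(1, 3, 38, 38))
plan = DG.get_pruning_plan(toy[0], tp.prune_conv, idxs=[0])  # drop one output channel of toy[0]
print(plan)  # the plan should list an in-channel pruning step for toy[2]
plan.exec()
print(toy)   # expected: Conv2d(3, 31, ...) followed by Conv2d(31, 32, ...);
             # if toy[2] still reports in_channels=32, the dependency was not tracked
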

The code I used is as follows:

import numpy as np
import torch
import torch.nn as nn
import torch_pruning as tp

def pruning():
    Model_saved = './Weight/'
    model = torch.load(Model_saved + 'Model.pth').cpu()

    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 38, 38).float().cpu())

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            mode = tp.prune_conv
        elif isinstance(m, nn.Linear):
            mode = tp.prune_linear
        elif isinstance(m, nn.BatchNorm2d):
            mode = tp.prune_batchnorm
        else:
            continue

        weight = m.weight.detach().cpu().numpy()
        out_channels = weight.shape[0]
        # Per-output-channel L1 norm. Note: np.sum(np.abs(weight)) without
        # flattening per channel collapses to a single scalar, and argsort
        # of a scalar always yields index 0.
        L1_norm = np.abs(weight).reshape(out_channels, -1).sum(axis=1)
        num_pruned = int(out_channels * 0.2)
        # Prune the 20% of output channels with the smallest L1 norm.
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()
        pruning_plan = DG.get_pruning_plan(m, mode, idxs=prune_index)
        print(pruning_plan)
        pruning_plan.exec()
    return model

In fact, every Inception block in my model runs into this problem, and I'm not sure of the specific cause. I would appreciate your advice; thank you for your help.
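Since the failures all involve Inception blocks, one way to narrow things down is to test pruning through torch.cat in isolation. A minimal sketch (my own toy block, not the actual model; names and channel counts are assumptions):

import torch
import torch.nn as nn
import torch_pruning as tp

class TinyInception(nn.Module):
    # Two parallel branches joined by torch.cat, which is the
    # dependency the graph has to track for Inception-style blocks.
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Conv2d(3, 16, 3, padding=1)
        self.branch2 = nn.Conv2d(3, 16, 5, padding=2)
        self.merge = nn.Conv2d(32, 8, 1)  # in_channels = 16 + 16

    def forward(self, x):
        return self.merge(torch.cat([self.branch1(x), self.branch2(x)], dim=1))

model = TinyInception()
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 3, 38, 38))
plan = DG.get_pruning_plan(model.branch1, tp.prune_conv, idxs=[0, 1])
plan.exec()
print(model.merge)  # should report in_channels=30; if it still says 32,
                    # the cat dependency was missed in your installed version
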

cyf666-coder avatar May 18 '21 06:05 cyf666-coder

I faced a similar issue on a different architecture. Check the number of channels you're passing as input; I think your input should be [16, 32, 38, 38].
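A quick way to apply that check (my own sketch, with the batch size and spatial size taken from the error message): run one dummy forward pass right after pruning, so the first layer with mismatched channels raises immediately instead of failing later during training.

import torch

model = pruning()  # the function from the original post
model.eval()
with torch.no_grad():
    out = model(torch.randn(16, 3, 38, 38))
print(out.shape)  # if this prints, every layer's in/out channels line up
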

Aarsh2001 avatar May 22 '21 17:05 Aarsh2001

Hi @cyf666-coder, any progress on this issue?

edwardnguyen1705 avatar Mar 02 '22 09:03 edwardnguyen1705