
Matrix shape error with iterative pruning

AntSim12 opened this issue 1 year ago · 2 comments

Hello there,

I am trying to implement an iterative pruning process of the timm models.

The following code works fine when the number of iterative steps is small (i.e. up to ~30-40), but it suddenly breaks for higher values, which prevents running a longer iterative pruning schedule.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (197x760 and 768x2304)

Would you have an idea of what is going on?
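For reference, the shapes in the error can be decoded against beit_base_patch16_224 (embed_dim 768, 12 heads, 197 tokens for a 224x224 input with 16x16 patches). This is only my reading of the error message, not a diagnosis: the activations reaching some qkv projection have been pruned down to 760 features, while that projection still expects the original 768, and 760 is not even a multiple of the 12 heads:

```python
# My reading of the failing matmul, using beit_base_patch16_224 reference values.
seq_len = 197                 # (224 // 16) ** 2 patches + 1 CLS token
mat1_shape = (seq_len, 760)   # activations after pruning: hidden width 760
mat2_shape = (768, 2304)      # a qkv Linear still at full width: 768 -> 3 * 768

print(mat1_shape[1] == mat2_shape[0])   # False: 760 != 768, hence the RuntimeError
print(760 % 12, 768 % 12)               # 4 0 -> the pruned width is no longer head-aligned
```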

FYI I am using:

torch-pruning==1.4.3
torch==2.4.0
timm==1.0.9

import torch
import torch.nn as nn
import timm
import torch_pruning as tp

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model("beit_base_patch16_224", pretrained=True, no_jit=True).eval().to(device)

imp = tp.importance.GroupNormImportance()
iterative_steps=50
    
input_size = model.default_cfg['input_size']
example_inputs = torch.randn(1, *input_size).to(device)
test_output = model(example_inputs)
ignored_layers = []
num_heads = {}
pruning_ratio_dict = {}
print("========Before pruning========")
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)


for m in model.modules():
    #if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
    if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
        ignored_layers.append(m)
        print("Ignore classifier layer: ", m)

    # Attention layers
    if hasattr(m, 'num_heads'):
        if hasattr(m, 'qkv'):
            num_heads[m.qkv] = m.num_heads
            print("Attention layer: ", m.qkv, m.num_heads)
        elif hasattr(m, 'qkv_proj'):
            num_heads[m.qkv_proj] = m.num_heads


pruner = tp.pruner.MetaPruner(
                model, 
                example_inputs, 
                global_pruning=False,
                importance=imp,
                iterative_steps=iterative_steps, 
                pruning_ratio=0.5,
                pruning_ratio_dict=pruning_ratio_dict,
                num_heads=num_heads,
                ignored_layers=ignored_layers,
            )
for stp in range(iterative_steps):
    pruner.step()


    for m in model.modules():
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qkv_proj]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

    print("========After pruning========")
    test_output = model(example_inputs)
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))

AntSim12 avatar Oct 23 '24 14:10 AntSim12

hi, did you solve this problem?

hellominjeong avatar Nov 27 '24 07:11 hellominjeong

hi, did you solve this problem?

Not yet, it still seems broken

AntSim12 avatar Dec 20 '24 10:12 AntSim12