Torch-Pruning icon indicating copy to clipboard operation
Torch-Pruning copied to clipboard

tp.utils.count_ops_and_params Error

Open DRVu16 opened this issue 11 months ago • 1 comments

import torch
from vit_pytorch import ViT
import torch_pruning as tp 
import vit_pytorch.vit

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def forward(self, x):
    B, N, C = x.shape
    x = self.norm(x)

    qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

    attn = self.attend(dots)
    attn = self.dropout(attn)

    out = torch.matmul(attn, v)
    out = out.transpose(1, 2).reshape(B, N, -1)
    return self.to_out(out)

    
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 16,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

imp = tp.importance.GroupNormImportance(p=1)


example_inputs = img
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)

num_heads = {}
ignored_layers = [model.mlp_head]
times = 0

for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        m.forward = forward.__get__(m, vit_pytorch.vit.Attention)
        num_heads[m.to_qkv] = m.heads
    if isinstance(m, vit_pytorch.vit.FeedForward):
        ignored_layers.append(m.net[4]) # only prune the internal layers of FFN & Attention

    pruner = tp.pruner.MetaPruner(
        model, 
        example_inputs, 
        global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
        importance=imp, # importance criterion for parameter selection
        pruning_ratio=0.5, # target pruning ratio
        ignored_layers=ignored_layers,
        num_heads=num_heads, # number of heads in self attention
        prune_num_heads=True, # reduce num_heads by pruning entire heads (default: False)
        prune_head_dims= False, # reduce head_dim by pruning featrues dims of each head (default: True)
        head_pruning_ratio=0.5, #args.head_pruning_ratio, # remove 50% heads, only works when prune_num_heads=True (default: 0.0)
        round_to=1
    )


for i, g in enumerate(pruner.step(interactive=True)):
    g.prune()

head_id = 0
for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        print("Head #%d"%head_id)
        print("[Before Pruning] Num Heads: %d, Head Dim: %d =>"%(m.heads, m.dim_head))
        m.num_heads = pruner.num_heads[m.to_qkv]
        m.head_dim = m.to_qkv.out_features // (3 * m.heads)
        print("[After Pruning] Num Heads: %d, Head Dim: %d"%(m.heads, m.dim_head))
        print()
        head_id+=1

print("----------------------------------------")
print("Summary:")
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Base MACs: %.2f G, Pruned MACs: %.2f G"%(base_macs/1e9, pruned_macs/1e9))
print("Base Params: %.2f M, Pruned Params: %.2f M"%(base_params/1e6, pruned_params/1e6))

I'm trying to prune a ViT model implemented in vit_pytorch but got following error:

Summary:
Traceback (most recent call last):
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 93, in <module>
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch_pruning\utils\op_counter.py", line 35, in count_ops_and_params
    _ = flops_model(example_inputs)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 123, in forward
    x = self.transformer(x)
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 79, in forward
    x = attn(x) + x
  File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 13, in forward
    qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
RuntimeError: shape '[1, 65, 3, 16, 64]' is invalid for input of size 99840

It is likely to have problems after pruning. How to solve this one ? @VainF or anyone else could you help me tp solve this problem. Thank you

DRVu16 avatar Dec 30 '24 08:12 DRVu16

I think the problem is that after pruning the shape of your model have changed. My guess come from Modify static attributes or forward functions in the readme.

Cyber-Vadok avatar Feb 18 '25 21:02 Cyber-Vadok