Torch-Pruning
tp.utils.count_ops_and_params Error
import torch
from vit_pytorch import ViT
import torch_pruning as tp
import vit_pytorch.vit
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def forward(self, x):
    B, N, C = x.shape
    x = self.norm(x)
    # note: this reshape reads self.heads and self.dim_head
    qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    attn = self.attend(dots)
    attn = self.dropout(attn)
    out = torch.matmul(attn, v)
    out = out.transpose(1, 2).reshape(B, N, -1)
    return self.to_out(out)
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 16,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
imp = tp.importance.GroupNormImportance(p=1)
example_inputs = img
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
num_heads = {}
ignored_layers = [model.mlp_head]
for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        m.forward = forward.__get__(m, vit_pytorch.vit.Attention)  # replace the einops-based forward
        num_heads[m.to_qkv] = m.heads
    if isinstance(m, vit_pytorch.vit.FeedForward):
        ignored_layers.append(m.net[4])  # only prune the internal layers of FFN & Attention
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
    importance=imp, # importance criterion for parameter selection
    pruning_ratio=0.5, # target pruning ratio
    ignored_layers=ignored_layers,
    num_heads=num_heads, # number of heads in self-attention
    prune_num_heads=True, # reduce num_heads by pruning entire heads (default: False)
    prune_head_dims=False, # reduce head_dim by pruning the feature dims of each head (default: True)
    head_pruning_ratio=0.5, # remove 50% of the heads; only works when prune_num_heads=True (default: 0.0)
    round_to=1
)
for i, g in enumerate(pruner.step(interactive=True)):
    g.prune()
head_id = 0
for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        print("Head #%d"%head_id)
        print("[Before Pruning] Num Heads: %d, Head Dim: %d =>"%(m.heads, m.dim_head))
        m.num_heads = pruner.num_heads[m.to_qkv]
        m.head_dim = m.to_qkv.out_features // (3 * m.heads)
        print("[After Pruning] Num Heads: %d, Head Dim: %d"%(m.heads, m.dim_head))
        print()
        head_id += 1
print("----------------------------------------")
print("Summary:")
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Base MACs: %.2f G, Pruned MACs: %.2f G"%(base_macs/1e9, pruned_macs/1e9))
print("Base Params: %.2f M, Pruned Params: %.2f M"%(base_params/1e6, pruned_params/1e6))
I'm trying to prune a ViT model implemented in vit_pytorch, but I got the following error:
Summary:
Traceback (most recent call last):
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 93, in <module>
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch_pruning\utils\op_counter.py", line 35, in count_ops_and_params
_ = flops_model(example_inputs)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1212, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 123, in forward
x = self.transformer(x)
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\vit_pytorch\vit.py", line 79, in forward
x = attn(x) + x
File "D:\Anaconda\envs\cvlface-env\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Desktop\ViT model pruning\vit-pytorch\ViT_pruning.py", line 13, in forward
qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, self.dim_head).permute(2, 0, 3, 1, 4)
RuntimeError: shape '[1, 65, 3, 16, 64]' is invalid for input of size 99840
It looks like something breaks after pruning. How can I solve this? @VainF or anyone else, could you help me solve this problem? Thank you.
I think the problem is that after pruning, the shapes inside your model have changed, but the attributes your patched forward reads have not been updated to match. My guess comes from the "Modify static attributes or forward functions" section of the README.
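The numbers in the traceback fit that explanation: with head_pruning_ratio=0.5, to_qkv now produces 3 * 8 * 64 = 1536 features per token, i.e. 1 * 65 * 1536 = 99840 elements, while the reshape still requests [1, 65, 3, 16, 64] = 199680 elements because self.heads is still 16. Your post-pruning loop writes m.num_heads and m.head_dim, but the patched forward reads self.heads and self.dim_head, so it keeps using the stale values. A minimal sketch of the corrected loop, reusing the names from your script (untested against your exact versions):

for m in model.modules():
    if isinstance(m, vit_pytorch.vit.Attention):
        # Write to the attributes the patched forward actually reads:
        # self.heads and self.dim_head, not num_heads / head_dim.
        m.heads = pruner.num_heads[m.to_qkv]
        m.dim_head = m.to_qkv.out_features // (3 * m.heads)

Since you set prune_num_heads=True and prune_head_dims=False, only the head count changes (16 -> 8) and dim_head stays 64, so self.scale should not need to be recomputed; if you also pruned head dims, you would have to update m.scale = m.dim_head ** -0.5 as well. With these two attributes kept in sync with the pruned to_qkv, the final count_ops_and_params call should run.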