Torch-Pruning
Matrix shape error with iterative pruning
Hello there,
I am trying to implement an iterative pruning process for timm models.
The following code works fine when the number of iterative steps is small (up to ~30-40), but it suddenly breaks for higher values, preventing longer iterative pruning runs:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (197x760 and 768x2304)
Would you have an idea of what is going on?
FYI I am using:
torch-pruning==1.4.3
torch==2.4.0
timm==1.0.9
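For what it's worth, the 760-vs-768 mismatch looks consistent with per-step rounding. A small arithmetic sketch (my own, assuming beit_base's embed_dim of 768) of what pruning_ratio=0.5 over 50 steps implies:

```python
# Hypothetical sketch (not from Torch-Pruning itself): with embed_dim = 768,
# a 0.5 target ratio over 50 steps asks for 768 * 0.5 / 50 = 7.68 channels
# per step, which is not an integer, so the pruner must round somewhere.
embed_dim, ratio, steps = 768, 0.5, 50

remaining = embed_dim
removed_per_step = []
for s in range(1, steps + 1):
    # cumulative target width after step s, rounded to an integer
    target = round(embed_dim * (1 - ratio * s / steps))
    removed_per_step.append(remaining - target)
    remaining = target

print(remaining)                      # 384 (the 0.5 target is reached)
print(sorted(set(removed_per_step)))  # [7, 8] -- step sizes alternate
```

Note that 760 is exactly the rounded width after the first step, while 768 is the unpruned width, so the error may mean one layer was narrowed while a layer consuming its output was not.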
import timm
import torch
import torch.nn as nn
import torch_pruning as tp

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model("beit_base_patch16_224", pretrained=True, no_jit=True).eval().to(device)
imp = tp.importance.GroupNormImportance()
iterative_steps = 50
input_size = model.default_cfg['input_size']
example_inputs = torch.randn(1, *input_size).to(device)
test_output = model(example_inputs)

ignored_layers = []
num_heads = {}
pruning_ratio_dict = {}

print("========Before pruning========")
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)

for m in model.modules():
    # Classifier head
    if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
        ignored_layers.append(m)
        print("Ignore classifier layer: ", m)
    # Attention layers
    if hasattr(m, 'num_heads'):
        if hasattr(m, 'qkv'):
            num_heads[m.qkv] = m.num_heads
            print("Attention layer: ", m.qkv, m.num_heads)
        elif hasattr(m, 'qkv_proj'):
            num_heads[m.qkv_proj] = m.num_heads

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    global_pruning=False,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5,
    pruning_ratio_dict=pruning_ratio_dict,
    num_heads=num_heads,
    ignored_layers=ignored_layers,
)

for stp in range(iterative_steps):
    pruner.step()
    for m in model.modules():
        # Attention layers: keep num_heads/head_dim consistent with the pruned qkv
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qkv_proj]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

print("========After pruning========")
test_output = model(example_inputs)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("MACs: %.4f G => %.4f G" % (base_macs / 1e9, pruned_macs / 1e9))
print("Params: %.4f M => %.4f M" % (base_params / 1e6, pruned_params / 1e6))
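To localize the failure, here is a small debugging sketch (my own helper, not a Torch-Pruning API) that forwards the model after every step and reports the first one that breaks:

```python
# Hypothetical helper: instead of only forwarding the model after all steps,
# validate shapes after each pruner.step() and return the first failing step.
def find_first_failing_step(model, pruner, example_inputs, num_heads, iterative_steps):
    """Return the first pruning step whose forward pass raises, or None."""
    for step in range(1, iterative_steps + 1):
        pruner.step()
        # same attention bookkeeping as in the main loop
        for m in model.modules():
            if hasattr(m, 'num_heads') and hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
        try:
            model(example_inputs)  # forward once to validate shapes
        except RuntimeError as e:
            print(f"first failure at step {step}: {e}")
            return step
    return None
```

Running this in place of the plain step loop would show whether things break at one specific step (for instance, when the remaining width stops lining up across layers) rather than only at the final forward pass.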
Hi, did you solve this problem?
Not yet, it still seems broken.