Erroneous pruning of Swin Transformers when prune_hf_swin.py is used with pruning ratios < 0.5
I was experimenting with prune_hf_swin.py to analyse how different pruning ratios affect the model's performance. However, whenever I use a pruning ratio below 0.5 (with nothing else changed in prune_hf_swin.py), I get an error. For example, this is the error with a pruning ratio of 0.2:
RuntimeError: Given normalized_shape=[304], expected input with shape [*, 304], but got input of size[1, 784, 224]
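For reference, the same message can be reproduced in isolation with a plain torch.nn.LayerNorm. This suggests that some LayerNorm in the pruned model still expects 304 channels while the tensor reaching it has only 224. Which layer that is remains an assumption to be checked; the snippet below only shows what the check itself compares.

```python
import torch
import torch.nn as nn

# A LayerNorm configured for 304 channels (the normalized_shape from the error above).
norm = nn.LayerNorm(304)

# An input shaped like the one in the traceback: batch 1, 784 tokens, 224 channels.
x = torch.randn(1, 784, 224)

try:
    norm(x)
    error_message = None
except RuntimeError as err:
    # LayerNorm requires the trailing dimension(s) of the input to equal normalized_shape.
    error_message = str(err)

print(error_message)

# A consistent pair runs fine: the last dimension matches normalized_shape.
ok = nn.LayerNorm(224)(x)
print(ok.shape)  # torch.Size([1, 784, 224])
```

So after pruning, the normalized_shape of each LayerNorm has to be kept in sync with the number of channels produced by the layer feeding it.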
There are no errors with pruning ratios above 0.5. Is this related to how the SwinPatchMergingPruner function is implemented?
If anyone has an idea about the reason behind this error, it would be very helpful. Thanks in advance.
Thanks for raising the issue! I may need a few days to investigate the bug.