TorchPruner
Training time Pruning?
Hi, thanks for sharing this GREAT code for pruning :D
I'm actually trying to prune some models (e.g. ResNet and friends) at training time.
First, I use torch.nn.utils.prune to find and build a mask for each module whose maximum parameter value is below my personal threshold.
After that, I use prune_model for the real (structural) prune and to avoid channel issues with skip connections.
However, when I do the real prune, I get an error like this while computing the loss from the pruned network in the training loop:
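A minimal sketch of that first masking step, assuming a single Conv2d layer and an illustrative threshold (the layer sizes and threshold value are hypothetical, not from the original setup):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

threshold = 0.1  # illustrative personal threshold
conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)

# Build a mask that zeroes weights whose absolute value is below the threshold.
mask = (conv.weight.detach().abs() >= threshold).float()
prune.custom_from_mask(conv, name="weight", mask=mask)

# Note: torch.nn.utils.prune only zeroes weights via a forward_pre_hook;
# it does not change tensor shapes. Structural (channel) removal is a
# separate step, e.g. TorchPruner's prune_model.
x = torch.randn(1, 16, 8, 8)
print(conv(x).shape)  # torch.Size([1, 32, 8, 8]) -- shape unchanged
```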
```
function normbackward1 returned an invalid gradient at index 0 - got [512, 120, 3, 3] but expected shape compatible with [512, 512, 3, 3] (validate_outputs at /pytorch/torch/csrc/autograd/engine.cpp:472)
```
Do you have any ideas or examples for training-time pruning?
Thanks in advance
Hi! Notice that when you use TorchPruner for pruning, you are actually slicing the weight tensors. This means you also need to make sure the weights of the layers around the one being pruned are sliced accordingly. With ResNet this is a bit tricky because of the skip connections. I would suggest first experimenting with a sequential model (VGG or similar), then moving to ResNet but pruning only those layers that are not affected by skip connections.
The error you get is likely because you pruned some weights but forgot to prune other tensors to match the new dimensions. In particular, you should check that the cascading_modules parameter is correct. I would need to see the architecture to provide more precise guidance.
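To illustrate why the surrounding layers matter, here is a minimal sketch in plain PyTorch (not TorchPruner internals; the layer sizes and kept indices are illustrative) of pruning output channels of one conv and cascading the slice into the next layer's input channels:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)

keep = [0, 2, 4, 6]  # output channels of conv1 we decide to keep

# Slice conv1's output channels (weight shape [out, in, kH, kW])...
conv1.weight = nn.Parameter(conv1.weight[keep].clone())
conv1.bias = nn.Parameter(conv1.bias[keep].clone())
conv1.out_channels = len(keep)

# ...and conv2's *input* channels to match -- this is the "cascading" step.
conv2.weight = nn.Parameter(conv2.weight[:, keep].clone())
conv2.in_channels = len(keep)

x = torch.randn(1, 3, 8, 8)
y = conv2(conv1(x))  # works only because BOTH layers were sliced
print(y.shape)  # torch.Size([1, 16, 8, 8])
```

If conv2's input channels were left at 8, the forward (and backward) pass would fail with exactly this kind of shape mismatch.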
Thanks for the fast reply!!
I agree that ResNet is tricky haha.. I'll try some experiments with VGG nets first 👍
I'm using a ResNet18 (slightly different from torchvision.models.resnet18 because of the input image size) to prune with my personal datasets. I will double-check cascading_modules.
Thank you!!!
Actually, when I test the pruned model returned by prune_model,
there are no channel issues when I run a forward pass to get predictions.
So I think there is no problem with cascading_modules, as you suggested.
However, I found something interesting: the module has a 'backward_hook'
that is not controlled by prune_model,
and it causes trouble when calling loss.backward().
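One possible source of stale hooks (an assumption, since the hook's origin isn't confirmed here): torch.nn.utils.prune leaves a hook on the module to re-apply the mask. A sketch of making that reparameterization permanent with prune.remove before structural pruning, and of inspecting the module's hook dicts (internal attributes that may change between PyTorch versions):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
prune.l1_unstructured(conv, name="weight", amount=0.5)

# prune.remove folds weight_orig * weight_mask into .weight and
# deletes the hook that re-applied the mask on every forward.
prune.remove(conv, "weight")

# Inspect whatever hooks remain on the module before calling backward():
print(dict(conv._forward_pre_hooks))  # {} after prune.remove
print(dict(conv._backward_hooks))
```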