torchprune
Discussion on the role of some classes
Hi,
I'm trying to understand the structure of the core pruning classes. I wanted to start this thread to see if you can explain the roles of the classes below, and how they generally interact with the main net class:
- Allocator
- Pruner
- Sparsifier
- Tracker
Any other extra information is also much appreciated.
Hi,
Thanks for reaching out. Here is an overview; please feel free to ask any follow-up questions if you want to go into detail about any of these:
The BaseCompressedNet takes care of the overall compression and provides the high-level API for compressing a network.
The network compression itself is split up into multiple parts:
- Allocator: The allocator takes an overall desired prune ratio for the network and allocates a per-layer prune ratio. In the simplest case, the allocator could just assign a constant per-layer prune ratio, for example.
- Pruner: The pruner takes the per-layer prune ratio as input and determines how many weights should be pruned from each filter or neuron in the layer. So it's almost like an allocator within a layer.
- Sparsifier: This class actually implements the sparsification. As input it takes the per-neuron/per-filter sparsity determined by the pruner and sparsifies the weight tensor.
- Tracker: A convenient wrapper around PyTorch's forward/backward hook functionality to track layer statistics that might be required during pruning.
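To make the division of responsibilities concrete, here is a toy sketch of how the Allocator, Pruner, and Sparsifier roles could compose. All names and signatures here are my own illustration, not torchprune's actual API, and plain nested lists stand in for weight tensors:

```python
def allocate(layers, global_ratio):
    """Allocator role: map one network-wide prune ratio to per-layer ratios.
    Simplest possible policy: the same ratio for every layer."""
    return {name: global_ratio for name in layers}


def budget_per_neuron(weights, layer_ratio):
    """Pruner role: decide how many weights to remove from each neuron/filter.
    Here: the same fraction from every row (an allocator within a layer)."""
    return [int(len(row) * layer_ratio) for row in weights]


def sparsify(weights, budgets):
    """Sparsifier role: actually zero out weights to meet each row's budget.
    Here: drop the smallest-magnitude weights in each row."""
    out = []
    for row, k in zip(weights, budgets):
        order = sorted(range(len(row)), key=lambda i: abs(row[i]))
        drop = set(order[:k])  # indices of the k smallest-magnitude weights
        out.append([0.0 if i in drop else w for i, w in enumerate(row)])
    return out


# One layer with two "neurons" of four weights each, pruned at 50%.
layers = {"conv1": [[0.9, -0.1, 0.4, 0.05], [0.2, -0.8, 0.01, 0.6]]}
ratios = allocate(layers, global_ratio=0.5)
budgets = budget_per_neuron(layers["conv1"], ratios["conv1"])
pruned = sparsify(layers["conv1"], budgets)
```

Magnitude-based selection in `sparsify` is just one possible criterion; the point of the sketch is only how the three decisions (network-level budget, per-neuron budget, actual zeroing) are kept separate.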
Hi,
Thanks for the explanation. I went through the code with the info you provided above, and it helped me get a better understanding of how things work.
I just need a bit more explanation of how the base classes for the allocator and pruner work. I'm having a hard time understanding the base logic behind these two classes just by reading the code. Any help would be appreciated.
Also, another question: I'm interested in playing around with structured sparsification of the model. Does the code remove the pruned filters from the model? I would like to get a real performance boost on hardware after performing the sparsification.
Thanks,
This is currently not supported, and I don't have plans to support it. If you are interested in inference-time speed-ups, I would recommend writing some type of post-processing tool that takes the network graph as input and outputs the slimmed model.
Alternatively, you can also hardcode the slimmed version for a particular architecture.
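For reference, the post-processing idea could look something like this toy sketch: find filters that were zeroed out, drop them, and remove the corresponding input channels from the following layer. This is my own illustration (plain nested lists in a filter-major layout instead of tensors), not part of torchprune:

```python
def slim_layer_pair(conv_w, next_w):
    """Physically remove pruned (all-zero) filters from a layer.

    conv_w: weights of the pruned layer, one inner list per filter.
    next_w: weights of the following layer, one inner list per output
            unit, whose columns correspond to conv_w's filters.
    Returns slimmed copies of both weight lists.
    """
    # Keep only filters that still contain at least one non-zero weight.
    keep = [i for i, f in enumerate(conv_w) if any(w != 0.0 for w in f)]
    slim_conv = [conv_w[i] for i in keep]
    # Drop the matching input channels (columns) of the next layer.
    slim_next = [[row[i] for i in keep] for row in next_w]
    return slim_conv, slim_next


# Filters 0 and 2 were fully sparsified, so only filter 1 survives,
# and the next layer shrinks from 3 input channels to 1.
conv_w = [[0.0, 0.0], [1.0, -2.0], [0.0, 0.0]]
next_w = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
slim_conv, slim_next = slim_layer_pair(conv_w, next_w)
```

A real tool would also have to handle biases, batch-norm parameters, and skip connections, which is why doing this generically over an arbitrary network graph is non-trivial.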
If you want to implement something like this, I would be happy to take PRs as well.