SimplePruning
CNN pruning with TensorFlow.
This repository provides a CNN channel pruning demo built with TensorFlow. You can prune your own model, defined in modelsets.py. Supported ops include conv2d, depthwise conv2d, pooling, fully connected, concat, add, and so on. Have a good time!
- Author: Haibo Wang
- Email: [email protected]
- Home Page: dasuda.top
Dependencies
- TensorFlow >= 1.10.0
- Python >= 3.5
- opencv-python >= 4.1.0
- numpy >= 1.14.5
- matplotlib >= 3.0.3
Getting Started
- Clone the repository
  $ git clone https://github.com/DasudaRunner/SimplePruning.git
- Download the Cifar10 dataset and put it into cifar-10-python/
  Url: http://www.cs.toronto.edu/~kriz/cifar.html
- (Optional) Define your model in modelsets.py
  You must use the add_layer() API defined in pruner.py to set up your model. See modelsets.py for more details.
- (Optional) Configure parameters in utils/config.py
  e.g. model name, learning rate, pruning rate.
- Train a full model. The .ckpt and .pb model files will be saved in ckpt_model/
  $ python full_train.py
- Channel pruning. The .ckpt and .pb model files will be saved in channels_pruned_model/
  $ python channels_pruning.py
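To give a feel for what the channel pruning step does to a pair of consecutive conv layers, here is a minimal NumPy sketch. It assumes filters are ranked by their L1 norm, which is a common criterion but is not confirmed to be the one channels_pruning.py uses. The names select_channels and prune_conv_pair are hypothetical and are not part of pruner.py.

```python
import numpy as np


def select_channels(kernel, pruning_rate):
    """Return the sorted indices of the output channels to keep.

    kernel: conv2d weights shaped (height, width, in_channels, out_channels),
            the layout TensorFlow uses for conv2d filters.
    pruning_rate: fraction of output channels to remove, e.g. 0.5.
    """
    out_channels = kernel.shape[-1]
    n_keep = max(1, int(round(out_channels * (1.0 - pruning_rate))))
    # Assumption: importance of a filter = L1 norm of its weights.
    scores = np.abs(kernel).sum(axis=(0, 1, 2))
    keep = np.argsort(scores)[-n_keep:]
    return np.sort(keep)


def prune_conv_pair(kernel, bias, next_kernel, pruning_rate):
    """Prune one conv layer and the matching input channels of the next one."""
    keep = select_channels(kernel, pruning_rate)
    pruned_kernel = kernel[..., keep]
    pruned_bias = bias[keep]
    # The following layer loses the input channels that were removed above.
    pruned_next = next_kernel[:, :, keep, :]
    return pruned_kernel, pruned_bias, pruned_next, keep


# Toy example: a 16->32 conv followed by a 32->64 conv, pruning rate 0.5.
rng = np.random.RandomState(0)
k1 = rng.normal(size=(3, 3, 16, 32))
b1 = np.zeros(32)
k2 = rng.normal(size=(3, 3, 32, 64))
p1, pb1, p2, keep = prune_conv_pair(k1, b1, k2, pruning_rate=0.5)
print(p1.shape, pb1.shape, p2.shape)
```

With a pruning rate of 0.5, the first layer keeps 16 of its 32 output channels, and the second layer's input dimension shrinks to match. A real pipeline would typically fine-tune the pruned model afterwards to recover accuracy.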
Supported ops (TensorFlow)
- Conv2d
- FullyConnected
- MaxPooling, AveragePooling
- BatchNormalization
- Activation
- DepthwiseConv2d
- GlobalMaxPooling, GlobalAveragePooling
- Concat
- Add
- Flatten
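Ops such as Concat need extra bookkeeping during channel pruning, because the channels kept in each input branch land at shifted positions in the concatenated tensor. The sketch below shows one way to compute that mapping. It is an illustration under assumptions, not the code in pruner.py, and concat_keep_indices is a hypothetical helper name.

```python
import numpy as np


def concat_keep_indices(keep_per_branch, channels_per_branch):
    """Map kept channel indices of each branch to indices in the concat output.

    keep_per_branch: list of index lists, one per concatenated input.
    channels_per_branch: original (unpruned) channel count of each input.
    """
    merged = []
    offset = 0
    for keep, n_channels in zip(keep_per_branch, channels_per_branch):
        # Shift each branch's indices by the channels that precede it.
        merged.extend(int(i) + offset for i in keep)
        offset += n_channels
    return np.array(merged, dtype=np.int64)


# Branch A has 4 channels and keeps [0, 2]; branch B has 6 and keeps [1, 4, 5].
merged = concat_keep_indices([[0, 2], [1, 4, 5]], [4, 6])
print(merged)
```

The layer that consumes the concatenated tensor would then keep only these input channels. An element-wise Add is more constrained: all of its inputs must end up with the same channels, so a pruner usually has to choose one shared index set for every branch feeding the Add.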
Evaluation on Cifar10 dataset
| Model | Dataset | Pruning rate | Model size (MB) | Inference time (ms per 64 images) |
|---|---|---|---|---|
| SimpleNet | cifar-10 | 0.5 | 8.7 -> 1.8 | 5.8 -> 2.7 |
| VGG19 | cifar-10 | 0.5 | 53.4 -> 13.5 | 28.62 -> 9.44 |
| DenseNet40 | cifar-10 | 0.5 | 4.3 -> 1.5 | 77.87 -> 39.97 |
| MobileNet V1 | cifar-10 | 0.5 | 6.6 -> 1.8 | 19.39 -> 8.01 |
| OCR-Net | --- | 0.5 | 2426.2 -> 841.9 | 10.36 -> 7.3 |
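The before/after numbers in the table can be turned into size reductions and speedups with a few lines of Python. The figures below are copied from the table; the summarize helper is only an illustration and is not part of the repository.

```python
# (model, size_before_mb, size_after_mb, time_before_ms, time_after_ms)
results = [
    ("SimpleNet", 8.7, 1.8, 5.8, 2.7),
    ("VGG19", 53.4, 13.5, 28.62, 9.44),
    ("DenseNet40", 4.3, 1.5, 77.87, 39.97),
    ("MobileNet V1", 6.6, 1.8, 19.39, 8.01),
    ("OCR-Net", 2426.2, 841.9, 10.36, 7.3),
]


def summarize(rows):
    """Return {model: (size reduction in percent, inference speedup factor)}."""
    summary = {}
    for name, size0, size1, time0, time1 in rows:
        reduction = 100.0 * (1.0 - size1 / size0)
        speedup = time0 / time1
        summary[name] = (round(reduction, 1), round(speedup, 2))
    return summary


summary = summarize(results)
for name, (reduction, speedup) in summary.items():
    print(f"{name}: size -{reduction}%, {speedup}x faster")
```

For instance, SimpleNet shrinks by roughly 79% and runs a little over twice as fast at a pruning rate of 0.5.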
Update logs
- 2019.07.24
  - Add support for the Add op.
  - Add support for ResNet18/ResNet34 in modelsets.py.
- 2019.07.16
  - Add support for the Concat op.
  - Add support for DenseNet40 in modelsets.py.
- 2019.07.14
  - Refactor SimplePruning.