How to use ASP on a transformer model in an mmdetection project
Hi, everyone
As the title says, I want to automatically sparsify a transformer model using ASP. Does this work, and how do I use ASP in an mmdetection project?
Thanks in advance~
I added model = ASP.prune_trained_model(model, optimizer) to my mmdetection project and got the following error:
Traceback (most recent call last):
  File "tools/train.py", line 248, in <module>
    main()
  File "tools/train.py", line 238, in main
    custom_train_model(model,
  File "/home/xx/xx/projects/mmdet3d_plugin/bevformer/apis/train.py", line 30, in custom_train_model
    custom_train_detector(model,
  File "/home/xx/xx/projects/mmdet3d_plugin/xx/apis/mmdet_train.py", line 92, in custom_train_detector
    model = ASP.prune_trained_model(model, optimizer)
  File "/opt/conda/lib/python3.8/site-packages/apex/contrib/sparsity/asp.py", line 214, in prune_trained_model
    cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
  File "/opt/conda/lib/python3.8/site-packages/apex/contrib/sparsity/asp.py", line 124, in init_model_for_pruning
    add_sparse_attributes(name, sparse_module)
  File "/opt/conda/lib/python3.8/site-packages/apex/contrib/sparsity/asp.py", line 96, in add_sparse_attributes
    sparse_parameters = sparse_parameter_list[type(module)]
KeyError: <class 'mmcv.cnn.bricks.wrappers.Linear'>
So how do I use ASP correctly in mmdet? Thanks.
cc @ChongyuNVIDIA @jpool-nv
Hi, @crcrpar @ChongyuNVIDIA @jpool-nv
Any comments on this? Thanks in advance~
Hi @erwangccc ,
The module type of the mmcv wrapper (mmcv.cnn.bricks.wrappers.Linear) is not present in ASP's default sparse_parameter_list, which maps each layer type to the names of the parameters to prune. You can add the required module->parameter mapping with the custom_layer_dict argument to init_model_for_pruning(). You can see how this argument is used here.
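Why the wrapper fails even though it passes the whitelist: the traceback shows the parameter lookup is sparse_parameter_list[type(module)], i.e. keyed on the exact class, while the wrapper evidently got past a whitelist of [torch.nn.Linear, torch.nn.Conv2d], so that check seems to accept subclasses. The sketch below reproduces the mechanism with plain-Python stand-ins; Linear, WrapperLinear, default_sparse_parameter_list, sparse_params_for and build_table are hypothetical names for illustration, not apex, torch or mmcv code.

```python
# Hypothetical stand-ins; these are NOT the real torch/mmcv/apex classes.
class Linear:
    """Stand-in for torch.nn.Linear."""


class WrapperLinear(Linear):
    """Stand-in for mmcv.cnn.bricks.wrappers.Linear (a subclass of Linear)."""


# Stand-in for ASP's default "layer type -> parameter names" table.
default_sparse_parameter_list = {Linear: ["weight"]}


def sparse_params_for(module, table):
    # Exact-type lookup, as in the traceback: subclasses are NOT matched.
    return table[type(module)]


def build_table(custom_layer_dict=None):
    # Mimics extending the defaults with a user-supplied mapping.
    table = dict(default_sparse_parameter_list)
    if custom_layer_dict:
        table.update(custom_layer_dict)
    return table


if __name__ == "__main__":
    module = WrapperLinear()
    print(isinstance(module, Linear))  # True: a subclass-aware check passes
    try:
        sparse_params_for(module, build_table())
    except KeyError as err:
        print("KeyError:", err)  # same failure mode as in the traceback
    # With an explicit entry for the wrapper type, the lookup succeeds.
    print(sparse_params_for(module, build_table({WrapperLinear: ["weight"]})))  # ['weight']
```

Adding an entry for the exact wrapper class is what custom_layer_dict is for; subclassing alone does not help because the lookup never walks the class hierarchy.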
To do this, you'll need to call the three functions that prune_trained_model() wraps (init_model_for_pruning(), init_optimizer_for_pruning() and compute_sparse_masks()) yourself instead of using that convenience function, so that you can call init_model_for_pruning() directly and supply your custom_layer_dict.
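A minimal sketch of those three calls, assuming apex's ASP class and mmcv 1.x's mmcv.cnn.bricks.wrappers.Linear as named in the traceback. The helper name prune_trained_model_with_mmcv_layers is hypothetical, and the arguments simply copy what prune_trained_model() passes in the traceback plus the custom_layer_dict entry. If your model also uses other mmcv wrappers (e.g. its Conv2d wrapper), they would likely need entries of their own; that is an assumption, not something verified here.

```python
def prune_trained_model_with_mmcv_layers(model, optimizer):
    """Hypothetical helper: the three steps of ASP.prune_trained_model(),
    plus a custom_layer_dict entry for mmcv's Linear wrapper.

    Modifies model and optimizer in place; nothing is returned.
    """
    # Imported lazily so this module can be loaded without apex/mmcv installed.
    import torch
    from apex.contrib.sparsity import ASP
    from mmcv.cnn.bricks.wrappers import Linear as MMCVLinear

    # Step 1: add mask buffers to eligible layers. Same arguments as
    # prune_trained_model() uses (see traceback), plus the mapping ASP
    # lacks for the mmcv wrapper type (assumed to sparsify "weight").
    ASP.init_model_for_pruning(
        model,
        mask_calculator="m4n2_1d",
        verbosity=2,
        whitelist=[torch.nn.Linear, torch.nn.Conv2d],
        allow_recompute_mask=False,
        custom_layer_dict={MMCVLinear: ["weight"]},
    )
    # Step 2: patch the optimizer so pruned weights stay zero while fine-tuning.
    ASP.init_optimizer_for_pruning(optimizer)
    # Step 3: compute the sparsity masks and apply them to the weights.
    ASP.compute_sparse_masks()
```

In custom_train_detector you would then call prune_trained_model_with_mmcv_layers(model, optimizer) without assigning the result. As far as I can tell, prune_trained_model() in apex also works in place and has no return statement, so model = ASP.prune_trained_model(model, optimizer) would leave model set to None even once the KeyError is fixed.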