How to use ASP on a transformer model in an mmdetection project

Open erwangccc opened this issue 1 year ago • 2 comments

Hi, everyone

As the title says, I want to automatically sparsify a transformer model using ASP. Does this work, and how do I use ASP in an mmdetection project?

Thanks in advance~

erwangccc avatar Oct 14 '22 03:10 erwangccc

I added model = ASP.prune_trained_model(model, optimizer) to my mmdetection project, and the following error occurs:

Traceback (most recent call last):
  File "tools/train.py", line 248, in <module>
    main()
  File "tools/train.py", line 238, in main
    custom_train_model(model,
  File "/home/xx/xx/projects/mmdet3d_plugin/bevformer/apis/train.py", line 30, in custom_train_model
    custom_train_detector(model,
  File "/home/xx/xx/projects/mmdet3d_plugin/xx/apis/mmdet_train.py", line 92, in custom_train_detector
    model = ASP.prune_trained_model(model, optimizer)
  File "/opt/conda/lib/python3.8/site-packages/apex/contrib/sparsity/asp.py", line 214, in prune_trained_model
    cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
  File "/opt/conda/lib/python3.8/site-packages/apex/contrib/sparsity/asp.py", line 124, in init_model_for_pruning
    add_sparse_attributes(name, sparse_module)
  File "/opt/conda/lib/python3.8/site-packages/apex/contrib/sparsity/asp.py", line 96, in add_sparse_attributes
    sparse_parameters = sparse_parameter_list[type(module)]
KeyError: <class 'mmcv.cnn.bricks.wrappers.Linear'>

So how do I use ASP correctly in mmdet? Thanks.

erwangccc avatar Oct 14 '22 03:10 erwangccc

cc @ChongyuNVIDIA @jpool-nv

crcrpar avatar Oct 14 '22 06:10 crcrpar

Hi, @crcrpar @ChongyuNVIDIA @jpool-nv

Any comments on this? Thanks in advance~

erwangccc avatar Oct 17 '22 02:10 erwangccc

Hi @erwangccc ,

The mmcv wrapper's module type is not present in the default sparse_parameter_list. You can add the required module->parameter mapping with the custom_layer_dict argument to init_model_for_pruning(). You can see how this argument is used here.
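To illustrate why the lookup in the traceback fails, here is a small sketch using stand-in classes rather than the real torch and mmcv types. It assumes that the mmcv wrapper subclasses torch.nn.Linear and that the whitelist is checked with isinstance, while the traceback shows the parameter mapping being looked up by exact type. The class names below are hypothetical placeholders.

```python
# Stand-in classes only; these are NOT the real torch/mmcv types.
class Linear:                 # plays the role of torch.nn.Linear
    pass


class WrappedLinear(Linear):  # plays the role of mmcv.cnn.bricks.wrappers.Linear
    pass


# Mapping keyed by exact type, like sparse_parameter_list in the traceback.
sparse_parameter_list = {Linear: ["weight"]}

layer = WrappedLinear()

# A subclass instance satisfies an isinstance check (assumed whitelist test)...
passes_whitelist = isinstance(layer, (Linear,))

# ...but its exact type is not a key of the mapping, so indexing with
# type(layer) would raise KeyError, as in the reported error.
has_mapping = type(layer) in sparse_parameter_list

# Adding an entry for the wrapper type makes the exact-type lookup succeed.
sparse_parameter_list.update({WrappedLinear: ["weight"]})
params_after_update = sparse_parameter_list[type(layer)]
```

If the real code behaves this way, any other wrapper type in the model would need its own entry as well.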

To do this, you'll need to call the three constituent functions inside of prune_trained_model() (instead of this convenience function) so you can call init_model_for_pruning() directly to supply your custom_layer_dict.
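A minimal sketch of that approach follows. The first call copies its arguments from the traceback above and adds custom_layer_dict. The other two function names (init_optimizer_for_pruning and compute_sparse_masks) are recalled from the apex source rather than taken from this thread, so please check them against your installed asp.py. The helper name prune_with_custom_layers and its asp parameter are hypothetical and exist only to keep the example self-contained.

```python
def prune_with_custom_layers(asp, model, optimizer, custom_layer_dict, whitelist):
    """Run the steps that prune_trained_model() wraps, passing custom_layer_dict.

    `asp` is expected to be apex.contrib.sparsity.ASP (passed in so that this
    sketch does not need apex at import time).
    """
    # Step 1: same arguments as in the traceback, plus the custom mapping.
    asp.init_model_for_pruning(
        model,
        mask_calculator="m4n2_1d",
        verbosity=2,
        whitelist=list(whitelist),  # pass a copy in case the callee modifies it
        allow_recompute_mask=False,
        custom_layer_dict=custom_layer_dict,
    )
    # Step 2 (assumed name): let the optimizer respect the sparsity masks.
    asp.init_optimizer_for_pruning(optimizer)
    # Step 3 (assumed name): compute and apply the 2:4 masks.
    asp.compute_sparse_masks()


# Possible usage in mmdet_train.py, assuming apex and mmcv are installed and
# that the wrapper's prunable parameter is named "weight":
#
#   import torch
#   from apex.contrib.sparsity import ASP
#   from mmcv.cnn.bricks.wrappers import Linear as MMCVLinear
#
#   prune_with_custom_layers(
#       ASP, model, optimizer,
#       custom_layer_dict={MMCVLinear: ["weight"]},
#       whitelist=[torch.nn.Linear, torch.nn.Conv2d],
#   )
```

The sketch does not reassign model: as far as I can tell these calls modify the model and optimizer in place, so the result does not need to be captured.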

jpool-nv avatar Oct 31 '22 15:10 jpool-nv