pytorch-image-models
[BUG] Feature extraction and gradient checkpointing
For models created for feature extraction, which are wrapped in FeatureListNet, trying to enable gradient checkpointing raises an error.
To Reproduce
Steps to reproduce the behavior:
- Start Google Colab
- Run the following:
!pip install timm
import timm
myModel = timm.create_model('resnet50', features_only=False)
myModel.set_grad_checkpointing()  # works as expected

myModel = timm.create_model('resnet50', features_only=True)
myModel.set_grad_checkpointing()
# fails with AttributeError: 'FeatureListNet' object has no attribute 'set_grad_checkpointing'
Expected behavior
Calling .set_grad_checkpointing() on a model created with features_only=True should work, and the forward pass should still return the pyramid features, computed with gradient checkpointing.
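To make the expected behavior concrete, here is a minimal pure-PyTorch sketch of a feature-list wrapper that supports gradient checkpointing. It is not timm's FeatureListNet. The class `CheckpointedFeatureList`, its constructor arguments, and the toy convolutional stages that stand in for ResNet-50's stages are all hypothetical. The only library call it relies on is `torch.utils.checkpoint.checkpoint` with `use_reentrant=False`, which is available in torch 1.13. The wrapper runs each stage in order. When checkpointing is enabled in training mode, it keeps only each stage's input and output rather than the intermediate activations, and recomputes those activations during backward. It returns the outputs of the selected stages as a list.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedFeatureList(nn.Module):
    """Hypothetical feature-pyramid wrapper (not timm's FeatureListNet).

    Runs `stages` in order and returns the output of every stage whose index
    is in `out_indices`, optionally using gradient checkpointing per stage.
    """

    def __init__(self, stages, out_indices):
        super().__init__()
        self.stages = nn.ModuleList(stages)
        self.out_indices = set(out_indices)
        self.grad_checkpointing = False

    def set_grad_checkpointing(self, enable=True):
        # Same method name as on full timm models; here it only toggles a flag.
        self.grad_checkpointing = enable

    def forward(self, x):
        features = []
        for i, stage in enumerate(self.stages):
            if self.grad_checkpointing and self.training:
                # Keep only the stage input and output; the stage's intermediate
                # activations are recomputed during the backward pass.
                x = checkpoint(stage, x, use_reentrant=False)
            else:
                x = stage(x)
            if i in self.out_indices:
                features.append(x)
        return features


# Toy usage: four stride-2 conv stages, returning the last three feature maps.
torch.manual_seed(0)
stages = [
    nn.Sequential(nn.Conv2d(3 if i == 0 else 8, 8, 3, stride=2, padding=1), nn.ReLU())
    for i in range(4)
]
model = CheckpointedFeatureList(stages, out_indices=(1, 2, 3))
model.set_grad_checkpointing()
feats = model(torch.randn(2, 3, 64, 64))
print([tuple(f.shape) for f in feats])  # [(2, 8, 16, 16), (2, 8, 8, 8), (2, 8, 4, 4)]
```

Checkpointing should change only memory use, not results. The features and gradients match those of the non-checkpointed forward pass, and the cost is recomputing each stage during backward.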
Desktop: the test was run in Google Colab.
- OS: Ubuntu 18.04
- timm-0.6.12
- torch 1.13.0+cu116
+1
Fixed in https://github.com/rwightman/pytorch-image-models/commit/2cfff0581b643347673cd19713cb3ea3b09c77a9