
[FEATURE] Chaining pooled output to classifier

Open ZeyuSun opened this issue 1 year ago • 0 comments

Motivation

Chaining unpooled output to classifier has been implemented and can be done as follows:

import timm
import torch

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
output = model.forward_features(torch.randn(2, 3, 256, 256))
classified = model.forward_head(output)

The penultimate linear layer outputs (pre-logits features) are just as useful in many tasks as the convolutional layer outputs ("pre-classifier" features). For example, we may want a vector embedding of an image to compute the distance between two images.

Current solutions are inefficient

The two current solutions either create a new network or change the network in-place:

  • Create with no classifier:
    m = timm.create_model('resnet50', pretrained=True, num_classes=0)
    
  • Remove it later:
    m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
    m.reset_classifier(0)
    

If I need to collect both the logits and the pre-logits features while iterating through batches, I have to define two networks that differ only in the last layer. This is not optimal: all the boilerplate has to be replicated, and it may cause out-of-memory errors for large networks.
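
A minimal sketch of what this duplication looks like in practice; the model name and the dummy batch (standing in for a real data loader) are just for illustration:

import timm
import torch

# Two copies of the same backbone, differing only in the classifier head.
full = timm.create_model('resnet50', pretrained=True).eval()                     # returns logits
headless = timm.create_model('resnet50', pretrained=True, num_classes=0).eval()  # returns pre-logits

with torch.no_grad():
    for batch in [torch.randn(2, 3, 224, 224)]:  # stand-in for iterating a data loader
        logits = full(batch)
        embeddings = headless(batch)  # the same backbone is evaluated a second time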

Potential solutions

One potential solution to get pre-logits features is to call model.forward_head with pre_logits=True (see the sketch after the list below). This works for most networks, but some networks do not accept the pre_logits argument:

repghost.py:304:    def forward_head(self, x):
ghostnet.py:291:    def forward_head(self, x):
inception_v3.py:374:    def forward_head(self, x):
tiny_vit.py:548:    def forward_head(self, x):
nasnet.py:556:    def forward_head(self, x):
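
For the networks that do accept it, a single model can provide both outputs. The sketch below also uses get_classifier(), a standard timm accessor, to turn the pre-logits features into logits without rerunning the backbone; this assumes the head ends in a single classifier layer applied to the pooled features, which holds for most but not all architectures:

import timm
import torch

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True).eval()
classifier = model.get_classifier()  # final classifier layer of the head

with torch.no_grad():
    feats = model.forward_features(torch.randn(2, 3, 256, 256))
    pre_logits = model.forward_head(feats, pre_logits=True)  # pooled pre-classifier embedding
    logits = classifier(pre_logits)                          # reuse the same head weights for logits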

A more general alternative is to set up a common interface that cuts the network in two and chains the halves. This requires passing the cut point to both forward_features and forward_head. However, this generality may not be necessary, because the convolutional layer outputs and the pre-logits outputs are arguably the two most important intermediate features.
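
Purely for illustration, such an interface might look like the following; the stop/start keyword arguments and the 'pre_logits' cut-point name are hypothetical and not part of the current timm API:

feats = model.forward_features(x, stop='pre_logits')    # hypothetical argument: where to cut
logits = model.forward_head(feats, start='pre_logits')  # hypothetical argument: where to resume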

ZeyuSun · May 24 '24 18:05