
Transformer backbone for dense prediction

Open wwjwy opened this issue 3 years ago • 11 comments

When using timm, a convolutional neural network can be used as a backbone for dense prediction via features_only=True, out_indices=(1, 2, 3, 4), but how can a transformer model be used directly as the backbone? Thanks!
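For reference, this is the CNN pattern the question describes; a minimal runnable sketch (resnet50 is just an example model):

import torch
import timm

# CNN backbone returning intermediate feature maps from four stages
model = timm.create_model('resnet50', pretrained=True, features_only=True, out_indices=(1, 2, 3, 4))
features = model(torch.randn(2, 3, 224, 224))
for f in features:
    print(f.shape)  # four NCHW maps at strides 4, 8, 16, 32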

wwjwy avatar Jul 28 '22 08:07 wwjwy

@wwjwy supporting all of the transformer models in this manner is a goal, but it has taken longer than expected; possibly done by end of summer. It's more work to figure out good schemes for all transformers given their differences: flat features and non-hierarchical features (where to take features from: all end blocks, evenly distributed across layers, etc.). Also, the extracted features commonly need extra norms, etc. applied...
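For instance, a hedged two-line sketch of the extra-norm point (feats stands in for tokens taken from an inner ViT block; this is illustrative, not an official API):

# hypothetical feats: [B, 197, 768] tokens from an inner block, still pre-norm
feats = model.norm(feats)  # apply the model's final LayerNorm before downstream use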

rwightman avatar Aug 01 '22 17:08 rwightman

Thanks so much for the response!

wwjwy avatar Aug 02 '22 00:08 wwjwy

I'm also looking forward to that day, can't wait to see your great work!!👍🏻👍🏻👍🏻

rainbow-xiao avatar Aug 03 '22 15:08 rainbow-xiao

@wwjwy @XL-H FYI, you can manually use some of the existing feature extraction helpers to make doing this yourself easier; you do have to determine the feature modules/nodes yourself by exploring the model, though. This is really the big part of finishing the feature: determining, for all models, the mapping of which blocks/block names are of interest, and what format they are in (H, W, flat or not, needs extra normalization, etc.).

Extract block outputs

import torch
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)
for n, m in model.named_modules():
    print(n)

output:

....
blocks.10.norm2
blocks.10.mlp
blocks.10.mlp.fc1
blocks.10.mlp.act
blocks.10.mlp.drop1
blocks.10.mlp.fc2
blocks.10.mlp.drop2
blocks.10.ls2
blocks.10.drop_path2
blocks.11
blocks.11.norm1
blocks.11.attn
blocks.11.attn.qkv
blocks.11.attn.attn_drop
blocks.11.attn.proj
blocks.11.attn.proj_drop
blocks.11.ls1
blocks.11.drop_path1
blocks.11.norm2
blocks.11.mlp
blocks.11.mlp.fc1
blocks.11.mlp.act
blocks.11.mlp.drop1
blocks.11.mlp.fc2
blocks.11.mlp.drop2
blocks.11.ls2
blocks.11.drop_path2
...
from timm.models import GraphExtractNet  # import path may vary across timm versions

fe = GraphExtractNet(model, ['blocks.9', 'blocks.10', 'blocks.11'])
o = fe(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)

output:

torch.Size([2, 197, 768])
torch.Size([2, 197, 768])
torch.Size([2, 197, 768])
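Note these are flat token sequences (197 = 196 patches + 1 class token). A minimal sketch of turning the last one into a 2D feature map for dense prediction, assuming a 224x224 input with patch size 16:

feats = o[-1][:, 1:, :]              # drop the class token -> [2, 196, 768]
B, N, C = feats.shape
H = W = int(N ** 0.5)                # 14 x 14 patch grid
fmap = feats.transpose(1, 2).reshape(B, C, H, W)  # [2, 768, 14, 14] NCHW map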

Extract attention maps

from torchvision.models.feature_extraction import get_graph_node_names
for n in get_graph_node_names(model)[0]:
    print(n)  

output:

...
blocks.10.attn.qkv
blocks.10.attn.reshape
blocks.10.attn.permute
blocks.10.attn.unbind
blocks.10.attn.getitem_3
blocks.10.attn.getitem_4
blocks.10.attn.getitem_5
blocks.10.attn.transpose
blocks.10.attn.matmul
blocks.10.attn.mul
blocks.10.attn.softmax
blocks.10.attn.attn_drop
blocks.10.attn.matmul_1
blocks.10.attn.transpose_1
blocks.10.attn.reshape_1
blocks.10.attn.proj
blocks.10.attn.proj_drop
blocks.10.ls1
blocks.10.drop_path1
blocks.10.add
blocks.10.norm2
blocks.10.mlp.fc1
blocks.10.mlp.act
blocks.10.mlp.drop1
blocks.10.mlp.fc2
blocks.10.mlp.drop2
blocks.10.ls2
blocks.10.drop_path2
blocks.10.add_1
blocks.11.norm1
blocks.11.attn.getattr
blocks.11.attn.getitem
blocks.11.attn.getitem_1
blocks.11.attn.getitem_2
blocks.11.attn.qkv
blocks.11.attn.reshape
blocks.11.attn.permute
blocks.11.attn.unbind
blocks.11.attn.getitem_3
blocks.11.attn.getitem_4
blocks.11.attn.getitem_5
blocks.11.attn.transpose
blocks.11.attn.matmul
blocks.11.attn.mul
blocks.11.attn.softmax
blocks.11.attn.attn_drop
blocks.11.attn.matmul_1
blocks.11.attn.transpose_1
blocks.11.attn.reshape_1
blocks.11.attn.proj
blocks.11.attn.proj_drop
blocks.11.ls1
blocks.11.drop_path1
blocks.11.add
blocks.11.norm2
blocks.11.mlp.fc1
blocks.11.mlp.act
blocks.11.mlp.drop1
blocks.11.mlp.fc2
blocks.11.mlp.drop2
blocks.11.ls2
blocks.11.drop_path2
blocks.11.add_1
norm
getitem_1
fc_norm
head
fe = GraphExtractNet(model, ['blocks.9.attn.softmax', 'blocks.10.attn.softmax', 'blocks.11.attn.softmax'])
o = fe(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)

output:

torch.Size([2, 12, 197, 197])
torch.Size([2, 12, 197, 197])
torch.Size([2, 12, 197, 197])
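Each map is [batch, heads, query, key]. A small sketch of reducing the last one to a per-patch saliency map (assuming the class token is query index 0):

attn = o[-1]                             # [2, 12, 197, 197]
cls_attn = attn[:, :, 0, 1:].mean(1)     # CLS query over the 196 patch keys, averaged over heads
cls_attn = cls_attn.reshape(-1, 14, 14)  # 14 x 14 spatial attention map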

rwightman avatar Aug 03 '22 17:08 rwightman

Thank you very much for your help, I will give it a try!

wwjwy avatar Aug 04 '22 01:08 wwjwy

Could you tell me how to import GraphExtractNet? NameError: name 'GraphExtractNet' is not defined

wwjwy avatar Aug 04 '22 02:08 wwjwy

To access the named blocks try this:


from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

def get_model_attention(model, x, blocks=["blocks.11.attn.softmax"]):
    # Build an FX-based extractor that returns the outputs of the named graph nodes
    model_attention = create_feature_extractor(model, blocks)
    attention = model_attention(x)
    return list(attention.values())[0].detach()


if __name__ == "__main__":
    import timm
    import torch
    model = timm.create_model("beit_base_patch16_384")
    x = torch.rand(1, 3, 384, 384)
    attention = get_model_attention(model, x)
    print(attention.shape)
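If you are unsure which node names a given model exposes, you can list them first with the already-imported helper (a small sketch; the filter string is just an example):

train_nodes, eval_nodes = get_graph_node_names(model)
print([n for n in eval_nodes if "softmax" in n])  # every attention softmax node in the graph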

lukasfolle avatar Aug 15 '22 08:08 lukasfolle

Could you tell me how to import GraphExtractNet? NameError: name 'GraphExtractNet' is not defined

Have you solved the feature extraction problem? Could you share how you handled it?

Bailey-24 avatar Dec 10 '22 13:12 Bailey-24


Trying the same approach with a Swin model fails:

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Proxy'

My code:

from torchvision.models.feature_extraction import create_feature_extractor

def get_model_attention(model, x, blocks=["layers.3.blocks.1.attn.softmax"]):
    model_attention = create_feature_extractor(model, blocks)
    attention = model_attention(x)
    return list(attention.values())[0].detach()


if __name__ == "__main__":
    import timm
    import torch
    model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, in_chans=5, num_classes=1024)
    input = torch.randn((1, 5, 224, 224))
    attention = get_model_attention(model, input)
    print(attention.shape)

how to fix this?

Bailey-24 avatar Dec 10 '22 13:12 Bailey-24

Could you please give an example of extracting features with Swin V2?

Thank you.

Bailey-24 avatar Dec 10 '22 13:12 Bailey-24

Just a heads up: code for extracting a dense (not 14 x 14 patch) feature map from a pretrained ViT was shared here.

There are two caveats to using their code:

  1. Their code currently only works with the model forward defined in DINO (which is based on the timm code, as they acknowledge).
  2. Their current hook for the "query/key/value" is redundant and calls the qkv layer unnecessarily. To fix that, this line should be modified to register the hook on block.attn.qkv instead of block.attn, and this line should be changed to qkv = output.reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)

Doing so, you can extract feature maps far denser than the default 14x14 patch grid of a patch-16 ViT.
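For concreteness, a minimal sketch of that qkv-hook arrangement on a stock timm ViT (num_heads is read from the attention module, since the qkv Linear itself does not carry it; names here are illustrative, not the linked project's code):

import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()
block = model.blocks[-1]
num_heads = block.attn.num_heads
captured = {}

def qkv_hook(module, args, output):
    # output of the fused qkv Linear: [B, N, 3 * C]
    B, N, C3 = output.shape
    C = C3 // 3
    captured['qkv'] = output.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)

handle = block.attn.qkv.register_forward_hook(qkv_hook)
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
handle.remove()

q, k, v = captured['qkv'].unbind(0)  # each [1, 12, 197, 64]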

aluo-x avatar Apr 07 '23 19:04 aluo-x