pytorch-image-models
Transformer backbone for dense prediction
When using timm, a convolutional neural network can be used as the backbone for dense prediction (features_only=True, out_indices=(1, 2, 3, 4)), but how can a transformer model be used directly as the backbone? Thanks!
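For reference, the CNN usage described above looks like this (a minimal sketch using timm's features_only API; resnet50 stands in as an example backbone):

import torch
import timm

# a ResNet-50 backbone returning feature maps at strides 4, 8, 16 and 32
m = timm.create_model('resnet50', pretrained=True, features_only=True, out_indices=(1, 2, 3, 4))
o = m(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)  # (2, 256, 56, 56), (2, 512, 28, 28), (2, 1024, 14, 14), (2, 2048, 7, 7)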
@wwjwy Supporting all of the transformer models in this manner is the goal, but it's taken longer than expected; possibly done by end of summer. It's more work to figure out good schemes for all transformers given their differences, with flat features and non-hierarchical features (where to take features from: all end blocks, evenly distributed across layers, etc.). Also, the extracted features commonly need extra norms, etc. applied...
Thanks so much for the response!
I'm also looking forward to that day, can't wait to see your great work!!👍🏻👍🏻👍🏻
@wwjwy @XL-H FYI, you can manually use some of the existing feature extraction helpers to make doing this yourself easier; you have to determine the feature modules/nodes manually by exploring the model, though. That is really the big part of finishing the feature: determining, for all models, the mapping of which blocks/block names are of interest and what format they are in (H, W; flat or not; needs extra normalization; etc.).
Extract block outputs
import torch
import timm
# GraphExtractNet lives in timm's FX feature-extraction helpers; the exact
# import path varies by timm version (timm.models.fx_features in older
# releases, timm.models._features_fx in newer ones)
from timm.models.fx_features import GraphExtractNet

m = timm.create_model('vit_base_patch16_224', pretrained=True)
for n, _ in m.named_modules():
    print(n)
output:
....
blocks.10.norm2
blocks.10.mlp
blocks.10.mlp.fc1
blocks.10.mlp.act
blocks.10.mlp.drop1
blocks.10.mlp.fc2
blocks.10.mlp.drop2
blocks.10.ls2
blocks.10.drop_path2
blocks.11
blocks.11.norm1
blocks.11.attn
blocks.11.attn.qkv
blocks.11.attn.attn_drop
blocks.11.attn.proj
blocks.11.attn.proj_drop
blocks.11.ls1
blocks.11.drop_path1
blocks.11.norm2
blocks.11.mlp
blocks.11.mlp.fc1
blocks.11.mlp.act
blocks.11.mlp.drop1
blocks.11.mlp.fc2
blocks.11.mlp.drop2
blocks.11.ls2
blocks.11.drop_path2
...
fe = GraphExtractNet(m, ['blocks.9', 'blocks.10', 'blocks.11'])
o = fe(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)
output:
torch.Size([2, 197, 768])
torch.Size([2, 197, 768])
torch.Size([2, 197, 768])
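Note these are flat token sequences, so a dense-prediction head usually needs them folded back into 2-D maps first. A minimal sketch, assuming ViT-B/16 at 224x224 (one class token followed by 196 patch tokens):

feat = o[-1][:, 1:, :]  # drop the class token -> (2, 196, 768)
feat = feat.transpose(1, 2).reshape(2, 768, 14, 14)  # NCHW feature map for a dense head
print(feat.shape)  # torch.Size([2, 768, 14, 14])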
Extract attention maps
from torchvision.models.feature_extraction import get_graph_node_names

for n in get_graph_node_names(m)[0]:
    print(n)
output:
...
blocks.10.attn.qkv
blocks.10.attn.reshape
blocks.10.attn.permute
blocks.10.attn.unbind
blocks.10.attn.getitem_3
blocks.10.attn.getitem_4
blocks.10.attn.getitem_5
blocks.10.attn.transpose
blocks.10.attn.matmul
blocks.10.attn.mul
blocks.10.attn.softmax
blocks.10.attn.attn_drop
blocks.10.attn.matmul_1
blocks.10.attn.transpose_1
blocks.10.attn.reshape_1
blocks.10.attn.proj
blocks.10.attn.proj_drop
blocks.10.ls1
blocks.10.drop_path1
blocks.10.add
blocks.10.norm2
blocks.10.mlp.fc1
blocks.10.mlp.act
blocks.10.mlp.drop1
blocks.10.mlp.fc2
blocks.10.mlp.drop2
blocks.10.ls2
blocks.10.drop_path2
blocks.10.add_1
blocks.11.norm1
blocks.11.attn.getattr
blocks.11.attn.getitem
blocks.11.attn.getitem_1
blocks.11.attn.getitem_2
blocks.11.attn.qkv
blocks.11.attn.reshape
blocks.11.attn.permute
blocks.11.attn.unbind
blocks.11.attn.getitem_3
blocks.11.attn.getitem_4
blocks.11.attn.getitem_5
blocks.11.attn.transpose
blocks.11.attn.matmul
blocks.11.attn.mul
blocks.11.attn.softmax
blocks.11.attn.attn_drop
blocks.11.attn.matmul_1
blocks.11.attn.transpose_1
blocks.11.attn.reshape_1
blocks.11.attn.proj
blocks.11.attn.proj_drop
blocks.11.ls1
blocks.11.drop_path1
blocks.11.add
blocks.11.norm2
blocks.11.mlp.fc1
blocks.11.mlp.act
blocks.11.mlp.drop1
blocks.11.mlp.fc2
blocks.11.mlp.drop2
blocks.11.ls2
blocks.11.drop_path2
blocks.11.add_1
norm
getitem_1
fc_norm
head
fe = GraphExtractNet(m, ['blocks.9.attn.softmax', 'blocks.10.attn.softmax', 'blocks.11.attn.softmax'])
o = fe(torch.randn(2, 3, 224, 224))
for x in o:
    print(x.shape)
output:
torch.Size([2, 12, 197, 197])
torch.Size([2, 12, 197, 197])
torch.Size([2, 12, 197, 197])
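Each map is (batch, heads, queries, keys), so, for example, the class token's attention over the 196 patch tokens can be folded into a 14x14 saliency map (a hypothetical usage sketch, not from the thread):

attn = o[-1]  # (2, 12, 197, 197)
cls_attn = attn[:, :, 0, 1:]  # class-token query over the 196 patch tokens
cls_map = cls_attn.mean(1).reshape(2, 14, 14)  # head-averaged 14x14 map
print(cls_map.shape)  # torch.Size([2, 14, 14])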
Thank you very much for your help, I will give it a try!
Could you tell me how to import GraphExtractNet? I get: NameError: name 'GraphExtractNet' is not defined
To access the named blocks, try this:
from torchvision.models.feature_extraction import create_feature_extractor

def get_model_attention(model, x, blocks=["blocks.11.attn.softmax"]):
    model_attention = create_feature_extractor(model, blocks)
    attention = model_attention(x)
    return list(attention.values())[0].detach()

if __name__ == "__main__":
    import timm
    import torch

    model = timm.create_model("beit_base_patch16_384")
    input = torch.rand(1, 3, 384, 384)
    attention = get_model_attention(model, input)
    print(attention.shape)
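If I'm reading the model config right, for beit_base_patch16_384 this should print torch.Size([1, 12, 577, 577]): 12 heads attending over 576 patch tokens plus the class token.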
Have you solved the feature extraction problem? Could you share how you handled it?
TypeError: int() argument must be a string, a bytes-like object or a number, not 'Proxy'
my code:
from torchvision.models.feature_extraction import create_feature_extractor

def get_model_attention(model, x, blocks=["layers.3.blocks.1.attn.softmax"]):
    model_attention = create_feature_extractor(model, blocks)
    attention = model_attention(x)
    return list(attention.values())[0].detach()

if __name__ == "__main__":
    import timm
    import torch

    model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, in_chans=5, num_classes=1024)
    input = torch.randn((1, 5, 224, 224))
    attention = get_model_attention(model, input)
    print(attention.shape)
How can I fix this?
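For context, that Proxy error comes from torch.fx symbolic tracing hitting Swin's shape-dependent window arithmetic (int() gets called on a traced Proxy value). A tracing-free alternative is a plain forward hook; a minimal sketch, assuming timm's Swin WindowAttention exposes its softmax as a submodule at that path (as the node name above suggests):

import torch
import timm

model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, in_chans=5, num_classes=1024)
captured = {}

def save_attn(module, inputs, output):
    # output is the post-softmax attention for every window in the batch
    captured['attn'] = output.detach()

h = model.layers[3].blocks[1].attn.softmax.register_forward_hook(save_attn)
model(torch.randn(1, 5, 224, 224))
h.remove()
print(captured['attn'].shape)  # (num_windows * batch, num_heads, 49, 49) for a 7x7 window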
Could you please show how to extract features from Swin V2 as an example?
Thank you.
Just a heads up: code for extracting a dense feature map (not the 14 x 14 patch grid) from a pretrained ViT was shared here.
There are two caveats to using their code:
- Their code currently only works with the model forward defined in DINO (which is based on the timm code, as they acknowledge)
- Their current hook for the "query/key/value" is redundant and calls the qkv layer unnecessarily. To fix that, this line should be modified to register the hook on block.attn.qkv instead of block.attn, and this line modified to qkv = output.reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
Doing so, you can extract features far denser than the current 14x14 patch grid for a patch-16 ViT.
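A minimal sketch of that suggested hook, assuming a timm-style ViT where block.attn.qkv is the fused nn.Linear producing q, k and v (the helper names here are illustrative, not from their code):

import torch
import timm

m = timm.create_model('vit_base_patch16_224', pretrained=True)
block = m.blocks[-1]
num_heads = block.attn.num_heads
captured = {}

def qkv_hook(module, inputs, output):
    B, N, C3 = output.shape
    C = C3 // 3
    # split the fused projection into (q, k, v), each (B, heads, N, C // heads)
    captured['qkv'] = output.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4).detach()

h = block.attn.qkv.register_forward_hook(qkv_hook)
m(torch.randn(1, 3, 224, 224))
h.remove()
q, k, v = captured['qkv']  # each (1, 12, 197, 64)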