LVT
Get features from the LVT network
@Chenglin-Yang Hi there, and thank you for making this code available.
I am trying to use lvt.py, and I have set with_cls_head=False.
Now I run:
>>> import timm
>>> import torch
>>> from lvt import *
>>> model = timm.create_model('lvt', pretrained=False, num_classes=0, exportable=True)
>>> x=torch.rand(1,3,256,256)
>>> model(x).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'shape'
The model output is now a list of 4 tensors. How can I combine them? I would just like the features for downstream tasks.
The tensors have different dimensions:
>>> model(x)[0].shape
torch.Size([1, 64, 64, 64])
>>> model(x)[1].shape
torch.Size([1, 64, 32, 32])
>>> model(x)[2].shape
torch.Size([1, 160, 16, 16])
>>> model(x)[3].shape
torch.Size([1, 256, 8, 8])
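A minimal sketch of one option, assuming the backbone returns the list of four (B, C, H, W) tensors shown above: global-average-pool each stage and concatenate the results, giving a 64 + 64 + 160 + 256 = 544-dimensional descriptor. The helper name `pool_and_concat` is hypothetical (it is not part of lvt.py), and random tensors stand in for the real model output.

```python
import torch
import torch.nn.functional as F


def pool_and_concat(features):
    """Hypothetical helper (not in lvt.py): global-average-pool each
    (B, C, H, W) stage output and concatenate along the channel axis."""
    # adaptive_avg_pool2d(..., 1) gives (B, C, 1, 1); flatten(1) gives (B, C)
    pooled = [F.adaptive_avg_pool2d(f, 1).flatten(1) for f in features]
    return torch.cat(pooled, dim=1)  # (B, sum of all C_i)


# Stand-in tensors with the stage shapes reported above (B=1, 256x256 input).
feats = [
    torch.rand(1, 64, 64, 64),
    torch.rand(1, 64, 32, 32),
    torch.rand(1, 160, 16, 16),
    torch.rand(1, 256, 8, 8),
]
desc = pool_and_concat(feats)
print(desc.shape)  # torch.Size([1, 544])
```

With the real model this would presumably be `pool_and_concat(model(x))`. Concatenating all stages keeps some low-level detail from the early layers; whether that helps over using only the last stage depends on the downstream task.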
How do I combine these into a single output descriptor? For example, ResNet features would have shape (1, 512).
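If the goal is a ResNet-like descriptor, a common alternative is to pool only the deepest stage, which is roughly how a ResNet's pooled feature vector is formed. The sketch below assumes a hypothetical wrapper `PooledBackbone` and a `DummyStages` stand-in that mimics the reported shapes; with LVT this would give a 256-dimensional vector rather than 512, since the last stage has 256 channels.

```python
import torch
import torch.nn as nn


class PooledBackbone(nn.Module):
    """Hypothetical wrapper: run a multi-stage backbone and return a
    global-average-pooled vector from its last stage."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)           # assumed: list of (B, C_i, H_i, W_i)
        last = feats[-1]                   # deepest stage, e.g. (B, 256, 8, 8)
        return self.pool(last).flatten(1)  # (B, 256)


class DummyStages(nn.Module):
    """Stand-in that mimics the reported output shapes; replace with the real LVT model."""

    def forward(self, x: torch.Tensor):
        b = x.shape[0]
        return [
            torch.rand(b, 64, 64, 64),
            torch.rand(b, 64, 32, 32),
            torch.rand(b, 160, 16, 16),
            torch.rand(b, 256, 8, 8),
        ]


model = PooledBackbone(DummyStages())
out = model(torch.rand(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 256])
```

To use it with the model from the snippet above, something like `PooledBackbone(model)` should work, assuming `model(x)` keeps returning the list of stage tensors when with_cls_head=False.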
Thank you for your time.