LVT

Get features from the LVT network

Open • jamesrobertwilliams opened this issue 3 years ago • 0 comments

@Chenglin-Yang Hi there, and first of all, thank you for making this code available.

I am trying to use lvt.py, and I have set with_cls_head=False.

Now I run:

>>> import timm
>>> import torch
>>> from lvt import *
>>> model = timm.create_model('lvt', pretrained=False, num_classes=0, exportable=True)
>>> x=torch.rand(1,3,256,256)
>>> model(x).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'shape'

The model output is now a list of 4 tensors. How can I combine them? I just want a feature vector for downstream tasks.

The tensors have different shapes:

>>> model(x)[0].shape
torch.Size([1, 64, 64, 64])
>>> model(x)[1].shape
torch.Size([1, 64, 32, 32])
>>> model(x)[2].shape
torch.Size([1, 160, 16, 16])
>>> model(x)[3].shape
torch.Size([1, 256, 8, 8])

How do I combine these into a single output descriptor? For example, ResNet features would have shape (1, 512).
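One common way to get a single descriptor (a sketch, not necessarily what the LVT authors intend) is to global-average-pool each stage's feature map over its spatial dimensions. You can then either keep only the last stage, giving a 256-d vector (analogous to how ResNet's 512-d feature is the pooled output of its last stage), or concatenate all four stages, giving 64 + 64 + 160 + 256 = 544 dimensions. `pool_features` below is a hypothetical helper, and the random tensors only mimic the shapes listed above; in practice they would be the list returned by `model(x)`.

```python
import torch
import torch.nn.functional as F


def pool_features(feature_maps, stages=None):
    """Collapse a list of (B, C_i, H_i, W_i) maps into one (B, sum of C_i) tensor.

    Hypothetical helper: each selected map is global-average-pooled over its
    spatial dims, then the pooled vectors are concatenated along channels.
    `stages` picks which maps to use (default: all of them).
    """
    if stages is None:
        stages = range(len(feature_maps))
    pooled = []
    for i in stages:
        fmap = feature_maps[i]
        # (B, C, H, W) -> (B, C, 1, 1) -> (B, C)
        pooled.append(F.adaptive_avg_pool2d(fmap, 1).flatten(1))
    return torch.cat(pooled, dim=1)


# Stand-in tensors with the shapes reported for a 1x3x256x256 input.
feats = [
    torch.rand(1, 64, 64, 64),
    torch.rand(1, 64, 32, 32),
    torch.rand(1, 160, 16, 16),
    torch.rand(1, 256, 8, 8),
]

last_only = pool_features(feats, stages=[3])  # last stage only
all_stages = pool_features(feats)             # all four stages concatenated
print(last_only.shape)   # torch.Size([1, 256])
print(all_stages.shape)  # torch.Size([1, 544])
```

Pooling only the last stage is the closest analogue to the ResNet descriptor. Concatenating stages keeps some lower-level information at the cost of a larger vector. Which one works better downstream is task-dependent.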

Thank you for your time.

jamesrobertwilliams, Aug 19 '22 10:08