PyTorch-Pretrained-ViT

Extract the transformer intermediate layer

Open • leolv131 opened this issue 3 years ago • 1 comment

I want to extract the transformer's intermediate layer. I tried the following code, but it does not work: nn.Sequential(*list(model.children())). How should I do this?

leolv131, Jan 13 '22 12:01

Hi, I was trying to accomplish the same thing with the same code. Instead, I created a subclass of the ViT class that overrides the forward pass to skip the last two layers, which are used for classification.

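import pytorch_pretrained_vit as ptv
from pytorch_pretrained_vit.model import PositionalEmbedding1D


class EncoderVit(ptv.ViT):
    """ViT that returns the transformer output instead of class logits."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Positional embedding for 576 tokens of width 768 (no class-token slot).
        # It is newly constructed here, so it does not carry pretrained weights.
        self.positional_embedding = PositionalEmbedding1D(576, 768)

    def forward(self, x):
        # (batch, channels, H, W) -> (batch, num_patches, dim)
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)
        # Class token intentionally left out (would need `import torch`):
        # x = torch.cat((self.class_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.positional_embedding(x)
        x = self.transformer(x)
        # Stop here: the classification layers are never applied.
        return x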
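
A note on the numbers passed to PositionalEmbedding1D: 576 is consistent with a 384x384 input cut into 16x16 patches (24 x 24 = 576) with no extra slot for the class token, and 768 would then be the embedding width. The image and patch sizes are an assumption on my part, since they are not stated above. A minimal sketch of the arithmetic, with `num_patches` as a hypothetical helper that is not part of the library:

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Count the non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


# Assumed configuration: 384x384 images, 16x16 patches.
seq_len = num_patches(384, 16)

# Keeping the class token would add one more position to the sequence.
seq_len_with_class_token = seq_len + 1
```

If your model uses a different image or patch size, the first argument to PositionalEmbedding1D would need to change accordingly.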

I needed this to use the ViT as an encoder, and I'm guessing you do too. Hope this helps!
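
If what you need is the output of individual blocks inside the transformer rather than only the final encoder output, forward hooks are another option in plain PyTorch. Below is a rough sketch, not something taken from this library: `TinyEncoder` is a hypothetical stand-in model so the example runs without downloading weights, and `capture_block_outputs` is a helper name I made up. With the real ViT you would pass its own list of blocks (possibly something like `model.transformer.blocks`, but check that attribute name against your installed version).

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a ViT encoder: blocks applied in order."""

    def __init__(self, dim=8, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)]
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def capture_block_outputs(model, blocks, x):
    """Run one forward pass and return {block index: detached output}."""
    captured = {}
    handles = []
    for index, block in enumerate(blocks):
        # Bind `index` as a default argument so each hook keeps its own value.
        def hook(module, inputs, output, index=index):
            captured[index] = output.detach()

        handles.append(block.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always remove the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return captured


encoder = TinyEncoder()
tokens = torch.randn(2, 5, 8)  # (batch, sequence length, embedding dim)
features = capture_block_outputs(encoder, encoder.blocks, tokens)
```

Here `features[i]` holds the output of block `i`, so you can pick whichever depth you want without changing the model's forward method.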

Abeldewit, Jul 05 '22 11:07