PyTorch-Pretrained-ViT
Extract the transformer intermediate layer
I want to extract the output of the transformer's intermediate layer. I tried the following code, but it does not work: nn.Sequential(*list(model.children())). How should I do this?
Hi, I was trying to accomplish the same thing with the same code. Instead, I created a subclass of the ViT class that overrides the forward pass to skip the last two layers, which are used for classification.
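import pytorch_pretrained_vit as ptv
from pytorch_pretrained_vit.model import PositionalEmbedding1D


class EncoderVit(ptv.ViT):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Positional embedding sized for 576 tokens of dim 768
        # (presumably the patch count without the class token).
        self.positional_embedding = PositionalEmbedding1D(576, 768)

    def forward(self, x):
        # (B, C, H, W) -> (B, num_patches, dim)
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)
        # Class token intentionally left out:
        # x = torch.cat((model.class_token.expand(1, -1, -1), x), dim=1)
        x = self.positional_embedding(x)
        x = self.transformer(x)
        # Return the encoder output, skipping the classification layers.
        return x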
I needed this to use the ViT as an encoder, and I'm guessing you do too. Hope this helps!
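If you would rather not subclass, a forward hook is another common way to grab an intermediate output. The nn.Sequential(*list(model.children())) approach likely fails because the forward pass also contains plain tensor operations (flatten, transpose) that are not child modules. Below is a minimal sketch of the hook approach. ToyViT and extract_features are hypothetical names used only for illustration, and the toy model is a stand-in, not the library's ViT. Only the attribute names patch_embedding and transformer are taken from the code above.

```python
import torch
import torch.nn as nn


class ToyViT(nn.Module):
    """Hypothetical stand-in for a ViT: patch embedding, encoder, classifier head."""

    def __init__(self, dim=32, patch=4, num_classes=10):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        # Stand-in for the stack of transformer blocks (assumption for brevity).
        self.transformer = nn.Sequential(
            nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim)
        )
        self.norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        # flatten/transpose are tensor ops, not child modules, so chaining
        # model.children() in nn.Sequential does not reproduce forward().
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)  # (B, N, dim)
        x = self.transformer(x)
        return self.fc(self.norm(x).mean(dim=1))  # (B, num_classes)


def extract_features(model, layer, images):
    """Run one forward pass and return the output of `layer`."""
    captured = {}

    def hook(module, inputs, output):
        captured["features"] = output.detach()

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(images)
    finally:
        handle.remove()  # detach the hook so later calls are unaffected
    return captured["features"]


model = ToyViT().eval()
images = torch.randn(2, 3, 16, 16)
features = extract_features(model, model.transformer, images)
print(features.shape)  # (batch, patches, dim) = (2, 16, 32)
```

With the real model, the same helper should work by passing model.transformer (or any submodule) as the layer. Note that the class token is still part of the sequence in that case, since the original forward pass is left unchanged.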