pytorch-vsumm-reinforce
GoogLeNet implementation
In which part of the code is GoogLeNet, the first part of the DSN, specified?
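```python
import torch
from torchvision.models import GoogLeNet_Weights, googlenet

model = googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)  # replaces deprecated pretrained=True
model.eval()  # use BatchNorm running statistics, not per-batch statistics
# Drop dropout and fc; the stack then ends at the global average pool (pool5)
extractor = torch.nn.Sequential(*list(model.children())[:-2])
im = torch.randn(1, 3, 720, 1280)  # dummy NCHW frame; real frames should be ImageNet-normalized
with torch.no_grad():  # without this, .numpy() fails on a tensor that requires grad
    feature = extractor(im).cpu().numpy().flatten()  # [1,1024,1,1] -> [1024]
```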
I tried it like this.
I wonder whether the author uses the pool5 layer of GoogLeNet for feature extraction. Is that what your code does, or is additional code needed?
I wrote the code myself.
@ehdrndd Could you share your feature extraction code, or a link to it? Thanks! Feel free to leave a contact if that's convenient.
@ruanzhijian This may help you:
https://github.com/HERIUN/vsumm-reinforce_re/blob/main/generate_dataset.py