Towards-Unified-Surgical-Skill-Assessment

Request for pre-processing code release

Open: mustafa1728 opened this issue 3 years ago (2 comments)

Thank you for this great work.

I have been experimenting with the codebase on JIGSAWS and wanted to look more closely at how the features were prepared. In particular, I am interested in the ResNet-101 features with ten-crop augmentation used for the visual path. Could you please add the pre-processing code for the visual path to the repository?

Thank you.

mustafa1728, Dec 20 '21 09:12

Hi mustafa, thank you for the question.

The ResNet features are extracted with the pre-trained ResNet-101 from torchvision, and the augmentation is just torchvision's TenCrop transform. I don't have the pre-processing code at hand right now, but I think it can be reproduced easily.
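
To make the TenCrop step concrete, here is a minimal plain-Python sketch of the crop geometry, assuming torchvision's `five_crop`/`ten_crop` behave as documented: four corner crops plus a center crop, followed by the same five crops of the horizontally flipped frame. The function name `ten_crop_boxes` and the 480×640 frame size are hypothetical and used only for illustration; it is worth checking the result against the installed torchvision version.

```python
from typing import List, Tuple

# (flipped, top, left, height, width). When flipped is True, top/left refer to
# the horizontally flipped frame, as in torchvision's ten_crop.
Box = Tuple[bool, int, int, int, int]


def ten_crop_boxes(frame_hw: Tuple[int, int], crop_hw: Tuple[int, int]) -> List[Box]:
    h, w = frame_hw
    ch, cw = crop_hw
    if ch > h or cw > w:
        raise ValueError("crop size must not exceed the frame size")
    five = [
        (0, 0),            # top-left
        (0, w - cw),       # top-right
        (h - ch, 0),       # bottom-left
        (h - ch, w - cw),  # bottom-right
        (int(round((h - ch) / 2.0)), int(round((w - cw) / 2.0))),  # center
    ]
    # First the five crops of the original frame, then the same five of the flipped frame.
    return [(flipped, top, left, ch, cw) for flipped in (False, True) for top, left in five]


# Hypothetical 480x640 frame with the (256, 352) crop size used later in the thread.
for box in ten_crop_boxes((480, 640), (256, 352)):
    print(box)
```

For that hypothetical frame, the corner crops start at rows 0 and 224 and columns 0 and 288, and the center crop starts at (112, 144). Any change to the crop size, or a resize before cropping, moves every box, so the crops only match the released features if both settings match.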

Finspire13, Dec 21 '21 09:12

Thank you for your reply.

I tried to reproduce the feature extraction by processing each frame of the video separately and concatenating the resulting 2048-dimensional features. For the processing, I used the class in the code snippet below.

However, the SROCC values I get for the V path differ from those obtained with the features you provided. I suspect I have made a mistake in the crop size of the TenCrop transform or in the normalisation, but I am not sure.
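
SROCC here is Spearman's rank-order correlation coefficient, typically computed between predicted and ground-truth skill scores. A minimal pure-Python sketch, assuming the usual definition: rank both sequences (tied values share the mean of the ranks they span), then take the Pearson correlation of the ranks. The helper names `average_ranks` and `srocc` and the example scores are hypothetical; in practice `scipy.stats.spearmanr` computes the same quantity.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """Return 1-based ranks; tied values share the mean of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        # Extend j across the run of values tied with values[order[i]].
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2.0 + 1.0
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def srocc(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman rank-order correlation: the Pearson correlation of the ranks."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("x and y must have the same length (at least 2)")
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    vx = sum((a - mx) ** 2 for a in rx)
    vy = sum((b - my) ** 2 for b in ry)
    if vx == 0 or vy == 0:
        raise ValueError("SROCC is undefined when one input is constant")
    return cov / (vx * vy) ** 0.5


predicted = [12.0, 18.5, 15.0, 22.0]   # hypothetical model outputs
ground_truth = [10, 19, 14, 25]        # hypothetical skill scores
print(srocc(predicted, ground_truth))  # 1.0: both sequences have the same order
```

Because SROCC depends only on the ordering of the predictions, small differences in the extracted features matter only insofar as they reorder the predicted scores across trials.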

Can you please help me out with this?

Thank you.

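import torch
from torchvision.models import resnet101 as ResNet101
from torchvision.transforms import Compose, Normalize, TenCrop

# ImageNet-pretrained ResNet-101. Newer torchvision releases deprecate
# `pretrained=True` in favour of the `weights=...` argument.
pre_model = ResNet101(
    pretrained=True,
    progress=True,
)

class FeatureExtractor:
    def __init__(self):
        # Drop the final fully connected layer; keep everything up to global average pooling.
        modules = list(pre_model.children())[:-1]
        self.model = torch.nn.Sequential(*modules)
        # Inference mode: BatchNorm uses its running statistics instead of
        # statistics computed over the ten crops in the batch.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

        # ImageNet mean/std scaled by 255; expects a float tensor (3, H, W) with values in [0, 255].
        self.pre_transform = Compose([
            Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)),
        ])
        # Four corner crops and the center crop, plus their horizontal flips; size is (H, W).
        self.transform = Compose([
            TenCrop((256, 352)),
        ])

    @torch.no_grad()
    def __call__(self, image):
        image = self.pre_transform(image)
        x = self.transform(image)      # tuple of 10 tensors, each (3, 256, 352)
        x = torch.stack(x, dim=0)      # (10, 3, 256, 352)
        feat = self.model(x)           # (10, 2048, 1, 1)
        return torch.flatten(feat, 1)  # (10, 2048): one feature vector per crop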
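
The Normalize constants in the snippet above are torchvision's documented ImageNet statistics (mean (0.485, 0.456, 0.406), std (0.229, 0.224, 0.225)) multiplied by 255. They therefore match the [0, 1] convention only if the input tensor holds float values in [0, 255]. If the frames were loaded with `ToTensor()`, which already rescales uint8 images to [0, 1], these constants would be off by a factor of 255. The plain-Python sketch below checks this per pixel; the helper names `normalize` and `conventions_agree` are hypothetical.

```python
import math

# torchvision's documented ImageNet statistics, for inputs scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Constants from the snippet, intended for inputs in [0, 255].
SNIPPET_MEAN = (123.675, 116.28, 103.53)
SNIPPET_STD = (58.395, 57.12, 57.375)


def normalize(pixel, mean, std):
    """Per-channel (x - mean) / std, the arithmetic torchvision's Normalize applies."""
    return tuple((p - m) / s for p, m, s in zip(pixel, mean, std))


def conventions_agree(pixel_255, tol=1e-9):
    """Compare [0, 255] input + snippet constants with [0, 1] input + ImageNet constants."""
    a = normalize(pixel_255, SNIPPET_MEAN, SNIPPET_STD)
    b = normalize([p / 255.0 for p in pixel_255], IMAGENET_MEAN, IMAGENET_STD)
    return all(math.isclose(x, y, rel_tol=tol, abs_tol=tol) for x, y in zip(a, b))


# The snippet constants are the ImageNet statistics scaled by 255 (up to float rounding) ...
print(all(math.isclose(255 * m, s) for m, s in zip(IMAGENET_MEAN, SNIPPET_MEAN)))
# ... so both conventions give the same normalized pixel.
print(conventions_agree((0, 128, 255)))
# Feeding a [0, 1] pixel to the snippet constants gives a very different value.
print(normalize((0.5, 0.5, 0.5), SNIPPET_MEAN, SNIPPET_STD))
```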


mustafa1728, Dec 26 '21 12:12