
Reproduce zero-shot results on the CIFAR10 and MNIST datasets

Open · xcpeng opened this issue 2 years ago · 12 comments

Thank you for your work on CLIP!

I was trying to reproduce the zero-shot prediction results listed in Table 11 of the paper. Though I can reproduce most of the results in Table 11, I found large gaps on the CIFAR10 and MNIST datasets.

Here are the results using the released RN50 model:

| Dataset | CLIP RN50 (Table 11) | My result | Gap |
| --- | --- | --- | --- |
| MNIST | 66.6 | 58.2 | -8.4 |
| CIFAR10 | 75.6 | 71.5 | -4.1 |

I have attached my implementation below to reproduce my results. Could you help me check what leads to the gaps on CIFAR10 and MNIST?

With this implementation, I can reproduce most of the results in Table 11, except for 4 or 5 datasets (including MNIST and CIFAR10).

import torch
import clip
import os
from torchvision.datasets import MNIST, CIFAR10
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device)

# from https://github.com/openai/CLIP/blob/main/data/prompts.md
mnist_classes = ['0','1','2','3','4','5','6','7','8','9',]
mnist_templates = ['a photo of the number: "{}".',]
cifar10_classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck',]
cifar10_templates = [
    'a photo of a {}.',
    'a blurry photo of a {}.',
    'a black and white photo of a {}.',
    'a low contrast photo of a {}.',
    'a high contrast photo of a {}.',
    'a bad photo of a {}.',
    'a good photo of a {}.',
    'a photo of a small {}.',
    'a photo of a big {}.',
    'a photo of the {}.',
    'a blurry photo of the {}.',
    'a black and white photo of the {}.',
    'a low contrast photo of the {}.',
    'a high contrast photo of the {}.',
    'a bad photo of the {}.',
    'a good photo of the {}.',
    'a photo of the small {}.',
    'a photo of the big {}.',
]


class_map = {'MNIST': mnist_classes, 'CIFAR10': cifar10_classes}
template_map = {'MNIST': mnist_templates, 'CIFAR10': cifar10_templates}

@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res


@torch.no_grad()
def extract_text_features(dataset_name):
    # code borrowed from: https://github.com/openai/CLIP/blob/fcab8b6eb92af684e7ff0a904464be7b99b49b88/notebooks/Prompt_Engineering_for_ImageNet.ipynb
    class_names = class_map[dataset_name]
    templates = template_map[dataset_name]
    model.to(device)
    model.eval()

    zeroshot_weights = []
    for classname in class_names:
        texts = [template.format(classname) for template in templates]
        texts = clip.tokenize(texts).to(device)
        class_embeddings = model.encode_text(texts)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding /= class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights

mnist = MNIST(root=os.path.expanduser("~/.cache"), download=True, train=False)
cifar10 = CIFAR10(root=os.path.expanduser("~/.cache"), download=True, train=False)

for dataset in [mnist, cifar10]:
    # extract image feature, code borrowed from: https://github.com/openai/CLIP#zero-shot-prediction
    image_features = []
    image_labels = []
    for image, class_id in dataset:
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_feature = model.encode_image(image_input)
        image_feature /= image_feature.norm()
        image_features.append(image_feature)
        image_labels.append(class_id)
    image_features = torch.cat(image_features, dim=0).to(device)  # (N, D) matrix of normalized image features
    
    # extract text feature
    dataset_name = 'MNIST' if dataset == mnist else 'CIFAR10'
    text_features = extract_text_features(dataset_name)
    
    # compute top-1 accuracy
    logits = (100. * image_features @ text_features).softmax(dim=-1)
    image_labels = torch.tensor(image_labels).unsqueeze(dim=1).to(device)
    top1_acc = accuracy(logits, image_labels, (1,))
    print(f'top-1 accuracy for {dataset_name} dataset: {top1_acc[0]:.3f}')

xcpeng, Oct 27 '21

In addition, I checked the implementation in https://github.com/openai/CLIP/blob/fcab8b6eb92af684e7ff0a904464be7b99b49b88/notebooks/Prompt_Engineering_for_ImageNet.ipynb

I changed the dataset in the notebook from ImageNet to CIFAR10 and got 71.48%, which is very close to the result from my own implementation (71.5%).
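The change in the notebook is essentially just swapping the dataset/loader cell and the class/template lists, roughly along these lines (a sketch; `zeroshot_classifier` and `preprocess` are the notebook's own helpers, and `cifar10_classes`/`cifar10_templates` are the lists from my script above):

import os
import torch
from torchvision.datasets import CIFAR10

# replace the ImageNet dataset/loader cell with the CIFAR10 test split
images = CIFAR10(root=os.path.expanduser("~/.cache"), download=True,
                 train=False, transform=preprocess)
loader = torch.utils.data.DataLoader(images, batch_size=32, num_workers=2)

# build the zero-shot classifier from the CIFAR10 classes/prompts instead of ImageNet's
zeroshot_weights = zeroshot_classifier(cifar10_classes, cifar10_templates)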

I have attached the ipynb file.

zeroshot_cifar.zip

xcpeng, Oct 27 '21

My implementation gets 72.25% on CIFAR10 and 41.01% on CIFAR100, using the ImageNet templates.
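Concretely, by "the ImageNet templates" I mean ensembling the ImageNet prompts from data/prompts.md over the CIFAR class names, roughly like this sketch (only a few of the 80 prompts are shown; it reuses the `extract_text_features` helper from the first post):

# a handful of the 80 ImageNet prompts from data/prompts.md (the rest omitted here)
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a photo of the large {}.',
    'a photo of the small {}.',
    'itap of a {}.',
]

# same prompt ensembling as in the first post, just with the ImageNet templates
template_map['CIFAR10'] = imagenet_templates
text_features = extract_text_features('CIFAR10')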

tfwang08, Nov 26 '21

I tried other models on CIFAR10, and the results were lower than those reported in the paper.

tfwang08, Nov 26 '21

@xcpeng @tfwang96 I also get 71.5% zero-shot accuracy on CIFAR10. Have you solved this problem? I suspect the authors may have used different data preprocessing for CIFAR10.
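One way to poke at that (just a sketch, not what the paper necessarily did) is to print the released pipeline and swap pieces of it, e.g. how the 32x32 CIFAR images are upsampled:

from torchvision import transforms

# `preprocess` is the pipeline returned by clip.load('RN50', device)
print(preprocess)  # Resize(224, bicubic) -> CenterCrop(224) -> convert to RGB -> ToTensor -> Normalize(CLIP stats)

# hypothetical variant: bilinear upsampling instead of bicubic, everything else unchanged
alt_preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    *preprocess.transforms[1:],
])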

vtddggg, Dec 23 '21

No, I haven't figured out what causes the gap. It would be great if the authors could reveal more details about these datasets.

msxingpeng, Dec 28 '21

Me too.

tfwang08, Dec 29 '21

Have you guys tried UCF101? I obtain 58.6% with ViT-B/32, but the reported is 64.5%.

dengandong, Sep 23 '22

> Have you guys tried UCF101? I obtain 58.6% with ViT-B/32, but the reported is 64.5%.

Unfortunately no

msxingpeng, Sep 23 '22

I was trying to reproduce the zero-shot results, but I only get 49% accuracy on MNIST. Could you help me check what is wrong?

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, processor = clip.load('RN50')
test_loader = datasets.MNIST('data', train=False, transform=processor, download=True)

def zeroshot_classifier(model, classnames):
    with torch.no_grad():
        print('Loading text templates')
        zeroshot_weights = []
        template = ['a photo of the number: "{}".',]
        for classname in tqdm(classnames):
            texts = [t.format(classname) for t in template]  # format with class
            texts = clip.tokenize(texts).cuda()  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights

def test(model, device, name):
    print("device is {}".format(device))
    correct_num = torch.tensor(0).to(device)

    # note: torchvision's MNIST.classes are strings like '0 - zero', not bare digits
    text_features = zeroshot_classifier(model, test_loader.classes)

    for image, label in tqdm(DataLoader(test_loader, batch_size=10)):
        with torch.no_grad():
            features = model.encode_image(image.to(device))
        features /= features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * features @ text_features)
        probs = similarity.softmax(dim=-1)

        _, pred = torch.max(probs, 1)
        num = torch.sum(pred == label.to(device))

        correct_num = correct_num + num
        torch.cuda.empty_cache()

    print(correct_num)
    print('\n{} correct rate is {}'.format(name, correct_num / 10000))

test(model.to(device), device, 'CLIP model')

serena-li, Nov 11 '22

> Have you guys tried UCF101? I obtain 58.6% with ViT-B/32, but the reported is 64.5%.

Were you ever able to replicate their results? I am using a single center frame but only getting ~53% accuracy... not sure what is different.

EDIT: This notebook helped a lot! https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
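For reference, this is roughly the single-center-frame setup I mean (a sketch: the OpenCV frame extraction, the example video path, the illustrative class subset, and the single prompt are my own choices, not necessarily what the paper used):

import cv2
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device)

def center_frame(video_path):
    # grab the middle frame of the clip with OpenCV
    cap = cv2.VideoCapture(video_path)
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_FRAMES, n_frames // 2)
    _, frame = cap.read()
    cap.release()
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

# illustrative subset of the 101 class names; the UCF101 prompts are in data/prompts.md
ucf101_classes = ['applying eye makeup', 'archery', 'playing guitar']

with torch.no_grad():
    texts = clip.tokenize([f'a photo of a person {c}.' for c in ucf101_classes]).to(device)
    text_features = model.encode_text(texts)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # classify the center frame of one (hypothetical) UCF101 clip
    image = preprocess(center_frame("v_Archery_g01_c01.avi")).unsqueeze(0).to(device)
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    pred = (image_features @ text_features.T).argmax(dim=-1)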

zanedurante, Feb 15 '23

I thought this might be because CIFAR-10 has a different mean and variance, so I did the following before classifying:

from torchvision import transforms
#Got values from: https://stackoverflow.com/a/68123869/1953366
mean = (0.49139968, 0.48215827, 0.44653124)
std = (0.24703233, 0.24348505, 0.26158768)
preprocess.transforms[-1] = transforms.Normalize(mean, std)

This changed preprocess to:

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f95ac305ab0>
    ToTensor()
    Normalize(mean=(0.49139968, 0.48215827, 0.44653124), std=(0.24703233, 0.24348505, 0.26158768))
)

Sadly, for RN50 this reduced the CIFAR-10 accuracy to 70.17%.

aknirala, Jun 13 '23

Please refer to https://github.com/openai/CLIP/blob/fcab8b6eb92af684e7ff0a904464be7b99b49b88/notebooks/Prompt_Engineering_for_ImageNet.ipynb for this concern.

shyammarjit, Sep 28 '23