lit-code icon indicating copy to clipboard operation
lit-code copied to clipboard

Lower top-1 accuracy than reported for the pre-trained CIFAR-10 ResNet models

Open curious-synapse opened this issue 4 years ago • 2 comments

I just computed the accuracies of the pre-trained ResNet models on CIFAR-10 (using torch==1.8.1), and the student accuracies are much lower than those shown in the paper.

Here are the numbers I get: Teacher (resnet110) = 93.68%; Student (resnet20) = 91.85% (the plot in the paper shows roughly 93.09%); Student (resnet32) = 92.97%; Student (resnet44) = 93.36%; Student (resnet56) = 93.52%; Student (resnet110) = 93.7%.

Any ideas on why the performance is lower?

Here is the script I used to compute the accuracies:

import argparse
import numpy as np
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.cuda
import torch.optim
import torch.utils.data
import dataloaders
import resnet

# ---------------------------------------------------------------------------
# Teacher network evaluated by main().
teacher_arch = "resnet110"
# Same "./<role>/<arch>.pth" layout that main() uses for student checkpoints.
teacher_model_checkpoint = "./teacher/" + teacher_arch + ".pth"
# Mini-batch size handed to the training data loader in get_data_loaders().
batch_size = 32
# ---------------------------------------------------------------------------
def main():
    """Print the validation top-1 accuracy of the teacher and every student.

    Command-line options:
        --seed: when non-negative, seeds torch, numpy and ``random`` and
            forces deterministic cuDNN; the default (-1) skips all seeding.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', default=-1, type=int)
    opts = cli.parse_args()

    # Reproducibility is opt-in: only a non-negative seed enables it.
    if opts.seed >= 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        for seeder in (torch.manual_seed, np.random.seed, random.seed):
            seeder(opts.seed)

    # Only the validation loader is used for evaluation.
    _, val_loader = get_data_loaders()
    print("Loaded Data")

    teacher = load_model(teacher_arch, teacher_model_checkpoint)
    print("Teacher = ", teacher_arch)
    print("Teacher Validation accuracy = ",
          compute_accuracy(teacher, val_loader, train=False))

    student_archs = ("resnet20", "resnet32", "resnet44", "resnet56", "resnet110")
    for arch in student_archs:
        student = load_model(arch, f"./student/{arch}.pth")
        acc = compute_accuracy(student, val_loader, train=False)
        print("Student = ", arch)
        print("Student Validation accuracy = ", acc)

def compute_accuracy(model, loader, train=True):
    """Return the sample-weighted top-1 accuracy (percent) of ``model`` on ``loader``.

    The model is switched to eval mode and every batch is moved to CUDA.
    ``train`` is accepted for call-site compatibility but is not used.

    Returns:
        ``AverageMeter.avg`` after the loop. This is a 1-element tensor when
        at least one batch was seen, and 0 for an empty loader.
    """
    model.eval()
    meter = AverageMeter()

    for inputs, labels in loader:
        labels = labels.cuda(non_blocking=True)
        inputs = inputs.cuda(non_blocking=True).detach()

        # Inference only: no autograd graph for the forward pass or the metric.
        with torch.no_grad():
            logits = model(inputs)
            top1 = accuracy(logits, labels, topk=(1,))[0]
            # Weight by batch size so a short final batch is averaged correctly.
            meter.update(top1, inputs.size(0))

    return meter.avg

def get_data_loaders():
    """Build the CIFAR-10 loaders used by this script.

    Returns:
        ``(train_loader, val_loader)``. The training loader uses the
        module-level ``batch_size``. The validation loader uses whatever
        defaults ``dataloaders.CIFAR10DataLoaders.val_loader`` has.
    """
    factory = dataloaders.CIFAR10DataLoaders
    train = factory.train_loader(batch_size=batch_size)
    val = factory.val_loader()
    return train, val

def load_model(arch, model_pth):
    """Build a CIFAR ResNet of the given architecture and load its weights.

    Args:
        arch: key into ``resnet.resnet_models["cifar"]`` (e.g. "resnet20").
        model_pth: path to a checkpoint holding the model's state dict.

    Returns:
        The model moved to the current CUDA device, with every parameter
        frozen (``requires_grad=False``). Eval mode is NOT set here.
    """
    model = resnet.resnet_models["cifar"][arch]()
    # Deserialize onto the CPU, where the freshly built model lives. This
    # lets checkpoints saved on any GPU index load, and avoids an extra
    # transient copy of the weights on the GPU before load_state_dict copies them.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    state_dict = torch.load(model_pth, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.cuda()
    for p in model.parameters():
        p.requires_grad = False

    return model

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: model scores; dim 0 is the sample, dim 1 the class
            (``topk`` is taken along dim 1).
        target: ground-truth class index for each row of ``output``.
        topk: the values of k to evaluate.

    Returns:
        A list with one 1-element tensor per k. Each holds the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row j holds every sample's j-th guess
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed, non-contiguous layout of `pred`,
            # so `.view(-1)` can raise for k > 1 on recent PyTorch versions.
            # `.reshape` copies only when a view is impossible.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class AverageMeter:
    """Weighted running mean over a stream of values.

    Attributes:
        val: the value passed to the most recent ``update``.
        sum: weighted total, i.e. the sum of ``val * n`` over all updates.
        count: total weight (sum of ``n``) recorded so far.
        avg: ``sum / count`` after the latest update (0 before any update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all recorded statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

curious-synapse avatar Sep 29 '21 18:09 curious-synapse

Can you try downgrading torch? We verified the accuracy at the time of upload.

ddkang avatar Oct 01 '21 07:10 ddkang

I ran the above script again with torch==1.2.0 (which came out in 2019) and got the same results as above. Are the test accuracies plotted in the paper top-5 accuracies instead of top-1?

curious-synapse avatar Oct 03 '21 06:10 curious-synapse