torch-conv-kan

VGG ImageNet weights

Status: Open · opened by sohaiberrabii 4 months ago · 0 comments

Thank you for the amazing work. I have tried to use the VGG weights trained on ImageNet (https://huggingface.co/brivangl/vgg_kagn11_v2), but the accuracy remains zero. A normalization mismatch would lower the accuracy rather than drive it to (near) zero, so this does not seem to be a normalization discrepancy.

My eval script is straightforward:

import torch
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# On CUDA, trade cudnn autotuning for reproducible kernel selection.
if device.type == "cuda":
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
from torchvision import datasets, transforms

@torch.no_grad()
def test(net, loader, device):
    net = net.to(device)
    net.eval()
    correct = 0
    total = 0
    space_fmt = str(len(str(len(loader)))) + "d"
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)
        correct += (outputs.argmax(-1) == targets).sum().item()
        total += targets.size(0)
        if batch_idx % 100 == 0 or batch_idx == len(loader) - 1:
            acc = 100.*correct/total
            print(f"[{batch_idx:{space_fmt}}/{len(loader)}]\tacc: {acc:.3f}\tcorr/total: {correct}/{total}")

    print(f"Acc@1: {acc}")

from models import vggkagn
# Build the KAGN VGG11v2 classifier (3 input channels, 1000 classes) with the
# hyperparameters used for the published checkpoint.
model = vggkagn(3,
                1000,
                groups=1,
                degree=5,
                dropout=0.15,
                l1_decay=0,
                dropout_linear=0.25,
                width_scale=2,
                vgg_type='VGG11v2',
                expected_feature_shape=(1, 1),
                affine=True,
                )
# NOTE(review): presumably the model class mixes in huggingface_hub's
# PyTorchModelHubMixin. Its `from_pretrained` is a classmethod that RETURNS a new
# model carrying the downloaded weights; it does not modify `model` in place.
# Discarding the return value evaluates randomly initialised weights, which
# yields near-chance (~0.1%) top-1 accuracy. Bind the result; fall back to
# `model` only if this implementation happens to load in place and return None.
loaded = model.from_pretrained('brivangl/vgg_kagn11_v2')
model = loaded if loaded is not None else model

# Evaluation preprocessing: resize, center-crop to 224x224, convert to a
# tensor, and normalize with the per-channel mean/std below.
transforms_val = transforms.Compose([
        transforms.Resize(256, antialias=True),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
testset = datasets.ImageNet(root="/raid/datasets/imagenet", split="val", transform=transforms_val)
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False)

test(model, testloader, device)

Posted by sohaiberrabii on Oct 10 '24, 03:10