
Training with JIT is weird

Open huaxz1986 opened this issue 2 years ago • 6 comments

I trained ResNet50 on CIFAR10 and found that the result is weird. When training with IPEX and JIT, training is up to 4x faster, but the validation accuracy regresses to 28% (compared to the original 95%).


The key code is listed below.

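import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import intel_extension_for_pytorch as ipex
from torch import jit

# data_dir, args, and ResNet50 are defined elsewhere in the script.
print('==> Preparing data..')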
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root=data_dir, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=1)
train_epoch_steps = len(trainloader)
for (data, label) in trainloader:
    print(data.size())
    break

testset = torchvision.datasets.CIFAR10(
    root=data_dir, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=1)
test_epoch_steps = len(testloader)

net = ResNet50()
net = jit.trace(net,torch.rand(128,3,32,32))
net = net.to(memory_format=torch.channels_last)
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)
net,optimizer = ipex.optimize(net,optimizer=optimizer)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

huaxz1986 avatar Sep 08 '22 03:09 huaxz1986

Hi, TorchScript is normally not supposed to be used in training scenarios.

jingxu10 avatar Sep 08 '22 04:09 jingxu10

@jingxu10, as @huaxz1986 described, the JIT model works properly with stock PyTorch. So this should be an issue where IPEX introduces accuracy loss.

EikanWang avatar Sep 08 '22 05:09 EikanWang

Also, ipex.optimize doesn't work with JIT modules. Users are recommended to do ipex.optimize on imperative modules first before turning them to JIT modules via tracing.

jgong5 avatar Sep 08 '22 05:09 jgong5

> Also, ipex.optimize doesn't work with JIT modules. Users are recommended to do ipex.optimize on imperative modules first before turning them to JIT modules via tracing.

According to the issue, it seems that ipex.optimize introduces the accuracy loss for the JIT module. The expected behavior is that ipex.optimize does nothing to a JIT module. Otherwise, we should prompt the user with a message like "Please feed non-jit module to optimize".

EikanWang avatar Sep 10 '22 06:09 EikanWang
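
To illustrate the kind of check suggested above, here is a minimal sketch of a guard that rejects JIT modules before optimization. The helper name `ensure_imperative` is hypothetical and is not part of IPEX; the only library fact relied on is that modules returned by `torch.jit.trace` and `torch.jit.script` are instances of `torch.jit.ScriptModule`.

```python
import torch
import torch.nn as nn


def ensure_imperative(model: nn.Module) -> nn.Module:
    """Hypothetical guard: refuse JIT modules, pass imperative modules through."""
    # Traced and scripted modules both subclass torch.jit.ScriptModule.
    if isinstance(model, torch.jit.ScriptModule):
        raise TypeError("Please feed non-jit module to optimize")
    return model


if __name__ == "__main__":
    eager = nn.Linear(4, 2)
    traced = torch.jit.trace(eager, torch.rand(1, 4))

    ensure_imperative(eager)  # imperative module: accepted
    try:
        ensure_imperative(traced)  # JIT module: rejected with a clear message
    except TypeError as err:
        print(err)
```

A caller could run such a guard on the model right before `ipex.optimize`, so that the mistake in the original script fails loudly instead of silently degrading accuracy.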

will check this issue.

jingxu10 avatar Sep 12 '22 01:09 jingxu10

Please try invoking 'ipex.optimize()' before 'jit.trace()'.

jingxu10 avatar Oct 07 '22 21:10 jingxu10
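
For readers landing on this thread, the ordering recommended above (optimize the imperative module first, then trace) might look roughly like the sketch below. It is an assumption-laden illustration, not the fix confirmed in the thread: `TinyNet` is a hypothetical stand-in for ResNet50, the helper `optimize_then_trace` is made up for this example, and the sketch uses inference mode to stay small. In the training case, `ipex.optimize` would also receive `optimizer=optimizer` and return a `(model, optimizer)` pair, as in the original script. The IPEX import is guarded so the ordering can be read and run even where IPEX is not installed.

```python
import torch
import torch.nn as nn

try:
    import intel_extension_for_pytorch as ipex  # optional dependency
except ImportError:
    ipex = None  # fall back to stock PyTorch so the sketch still runs


class TinyNet(nn.Module):
    """Hypothetical small CNN standing in for ResNet50."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        x = x.mean(dim=(2, 3))  # global average pooling
        return self.fc(x)


def optimize_then_trace(model, example):
    """Hypothetical helper: ipex.optimize on the imperative module, then jit.trace."""
    model = model.to(memory_format=torch.channels_last)
    model.eval()
    if ipex is not None:
        # Step 1: optimize while the module is still imperative (non-JIT).
        model = ipex.optimize(model)
    with torch.no_grad():
        # Step 2: only now convert to a JIT module via tracing.
        return torch.jit.trace(model, example)


if __name__ == "__main__":
    example = torch.rand(2, 3, 32, 32)
    traced = optimize_then_trace(TinyNet(), example)
    print(traced(example).shape)
```

The key difference from the original script is that `jit.trace` is the last step, so `ipex.optimize` never sees a JIT module.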