intel-extension-for-pytorch
Training with JIT is weird
I trained ResNet50 on CIFAR10 and found a weird result. When training with IPEX and JIT, training speeds up by as much as 4x, but validation accuracy regresses to 28% (compared to the original 95%).


The key code is listed below.
```python
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch import jit
import intel_extension_for_pytorch as ipex

# ResNet50, data_dir, and args are defined elsewhere in the script.

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root=data_dir, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=1)
train_epoch_steps = len(trainloader)
for (data, label) in trainloader:
    print(data.size())
    break

testset = torchvision.datasets.CIFAR10(
    root=data_dir, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=1)
test_epoch_steps = len(testloader)

# The model is traced before ipex.optimize is applied.
net = ResNet50()
net = jit.trace(net, torch.rand(128, 3, 32, 32))
net = net.to(memory_format=torch.channels_last)
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)
net, optimizer = ipex.optimize(net, optimizer=optimizer)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
```
Hi, TorchScript is normally not supposed to be used in training scenarios.
@jingxu10, per @huaxz1986's description, the JIT model works properly with stock PyTorch, so this looks like an issue where IPEX introduces the accuracy loss.
Also, `ipex.optimize` doesn't work with JIT modules. Users are recommended to apply `ipex.optimize` to imperative modules first, before turning them into JIT modules via tracing; a sketch of this ordering follows below.
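For reference, here is a minimal sketch of the recommended ordering. It assumes the same `ResNet50` definition and hyperparameters as in the report above; the literal learning rate 0.1 merely stands in for `args.lr`.

```python
import torch
import torch.optim as optim
from torch import jit
import intel_extension_for_pytorch as ipex

net = ResNet50()  # same model definition as in the report above
net = net.to(memory_format=torch.channels_last)
optimizer = optim.SGD(net.parameters(), lr=0.1,  # 0.1 stands in for args.lr
                      momentum=0.9, weight_decay=5e-4)

# Optimize the imperative module first ...
net, optimizer = ipex.optimize(net, optimizer=optimizer)

# ... and only then trace it into a JIT module.
net = jit.trace(net, torch.rand(128, 3, 32, 32))
```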
According to the issue, it seems that `ipex.optimize` introduces the accuracy loss for the JIT module. The expected behavior should be that `ipex.optimize` does nothing to a JIT module; otherwise, we should prompt the user with a message like "Please feed a non-JIT module to optimize".
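As an illustration, such a guard could be a simple type check at the top of the optimizer entry point. This is a hypothetical sketch, not the actual IPEX implementation:

```python
import warnings
import torch

def optimize(model, optimizer=None, **kwargs):
    # Hypothetical guard: traced/scripted models are ScriptModule
    # instances, so warn and return the inputs unchanged instead of
    # silently producing a mis-optimized module.
    if isinstance(model, torch.jit.ScriptModule):
        warnings.warn("Please feed a non-JIT module to optimize().")
        return model, optimizer
    # ... the actual optimization path would follow here
    return model, optimizer
```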
We will check this issue.
Please try invoking `ipex.optimize()` before `jit.trace()`.