intel-extension-for-pytorch

Model trained on XPU fails to continue training after loading a checkpoint

Status: Open · hermanhmchan opened this issue 8 months ago · 5 comments

Describe the bug

I tried to load a checkpoint and continue training, following https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html#load-the-general-checkpoint. However, optimizer.step() fails with "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and cpu!"
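For reference, the linked recipe can be sketched end to end. This is a minimal, CPU-runnable sketch, not the reporter's code: a tiny `nn.Linear` stands in for resnet50, an in-memory buffer stands in for checkpoint.pth, and the device choice assumes `torch.xpu` is present once intel_extension_for_pytorch is imported. One difference from the reproducer below is deliberate: the model is moved to its device before the optimizer is built and before the state dicts are restored, because `Optimizer.load_state_dict` casts each optimizer state tensor to the device its parameter is on at that moment.

```python
import io

import torch

# Assumption: use an XPU when one is registered (e.g. after importing
# intel_extension_for_pytorch), otherwise fall back to the CPU.
DEVICE = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"


def make_model_and_optimizer(device):
    # Tiny stand-in for resnet50. The model is moved to `device` BEFORE the
    # optimizer is built, so the optimizer references the on-device params.
    model = torch.nn.Linear(4, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    return model, optimizer


# First run: one training step makes SGD create its momentum buffers,
# then a general checkpoint is saved (to memory instead of checkpoint.pth).
model, optimizer = make_model_and_optimizer(DEVICE)
optimizer.zero_grad()
model(torch.randn(8, 4, device=DEVICE)).sum().backward()
optimizer.step()

buffer = io.BytesIO()
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    buffer,
)

# Second run: rebuild on the target device first, then restore both states.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
restored_model, restored_optimizer = make_model_and_optimizer(DEVICE)
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
restored_model.train()

# load_state_dict cast each momentum buffer to its parameter's device,
# so buffers and parameters are now on the same device.
for param in restored_model.parameters():
    momentum = restored_optimizer.state[param]["momentum_buffer"]
    print(tuple(momentum.shape), momentum.device, param.device)
```

Loading with `map_location="cpu"` keeps the checkpoint device-neutral; the final placement is decided by where the parameters already are when the state dicts are restored.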

The following code reproduces the issue. It loads checkpoint.pth, which was saved by an earlier training run on XPU.

import torch
import torchvision

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

LR = 0.001
DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

PATH = "checkpoint.pth"

#################### Load model start #############################
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#################### Load model end ###############################

model.train()
######################## code changes #######################
model = model.to("xpu")
criterion = criterion.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
######################## code changes #######################

for batch_idx, (data, target) in enumerate(train_loader):
    ########## code changes ##########
    data = data.to("xpu")
    target = target.to("xpu")
    ########## code changes ##########
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    PATH,
)

print("Execution finished")

Output:

Files already downloaded and verified
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 54
     52     loss = criterion(output, target)
     53     loss.backward()
---> 54     optimizer.step()
     55     print(batch_idx)
     56 torch.save(
     57     {
     58         "model_state_dict": model.state_dict(),
   (...)
     61     PATH,
     62 )

File ~/torch/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:521, in sgd_step(self, closure)
    518         param2 = get_param2(p, self.params_attr)
    519         params2.append(param2)
--> 521 sgd(
    522     params_with_grad,
    523     params2,
    524     d_p_list,
    525     momentum_buffer_list,
    526     weight_decay=group["weight_decay"],
    527     momentum=group["momentum"],
    528     lr=group["lr"],
    529     dampening=group["dampening"],
    530     nesterov=group["nesterov"],
    531     maximize=group["maximize"],
    532     has_sparse_grad=has_sparse_grad,
    533     foreach=group["foreach"],
    534     fused=self.fused,
    535 )
    537 # update momentum_buffers in state
    538 for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):

File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:464, in sgd(params, params2, d_p_list, momentum_buffer_list, has_sparse_grad, foreach, weight_decay, momentum, lr, dampening, nesterov, maximize, fused)
    461 else:
    462     func = _single_tensor_sgd
--> 464 func(
    465     params,
    466     params2,
    467     d_p_list,
    468     momentum_buffer_list,
    469     weight_decay=weight_decay,
    470     momentum=momentum,
    471     lr=lr,
    472     dampening=dampening,
    473     nesterov=nesterov,
    474     has_sparse_grad=has_sparse_grad,
    475     maximize=maximize,
    476     fused=fused,
    477 )

File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:319, in _single_tensor_sgd(params, params2, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, has_sparse_grad, fused)
    317 grad = grads[i] if not maximize else -grads[i]
    318 if not grad.is_sparse:
--> 319     momentum_buffer_list[i] = torch.ops.torch_ipex.sgd_fused_step(
    320         param,
    321         grad,
    322         momentum_buffer_list[i],
    323         params2[i],
    324         momentum,
    325         lr,
    326         weight_decay,
    327         dampening,
    328         nesterov,
    329     )
    330     continue
    332 if (
    333     param.dtype == torch.bfloat16
    334     and grad.is_sparse
   (...)
    338 ):
    339     # packed_add can support sparse tensor

File ~/torch/lib/python3.10/site-packages/torch/_ops.py:692, in OpOverloadPacket.__call__(self, *args, **kwargs)
    687 def __call__(self, *args, **kwargs):
    688     # overloading __call__ to ensure torch.ops.foo.bar()
    689     # is still callable from JIT
    690     # We save the function ptr as the `op` attribute on
    691     # OpOverloadPacket to access it here.
--> 692     return self._op(*args, **kwargs or {})

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and cpu!
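One plausible explanation, not verified on an XPU here: in stock PyTorch, `Optimizer.load_state_dict` casts each loaded state tensor to the device its parameter is on at load time. The script above loads the optimizer state while the model is still on the CPU, so the SGD momentum buffers land on the CPU; `model.to("xpu")` then moves only the parameters, and the fused step receives `xpu:0` parameters together with `cpu` entries from `momentum_buffer_list`. The sketch below checks for exactly that condition and repairs it by moving the state onto each parameter's device. The helper names `find_state_device_mismatches` and `move_optimizer_state_to_param_devices` are hypothetical, not part of PyTorch or IPEX.

```python
import torch


def find_state_device_mismatches(optimizer):
    """Hypothetical helper: return (state_key, param_device, state_device)
    for every optimizer state tensor not on its parameter's device."""
    mismatches = []
    for param, state in optimizer.state.items():
        for key, value in state.items():
            if torch.is_tensor(value) and value.device != param.device:
                mismatches.append((key, param.device, value.device))
    return mismatches


def move_optimizer_state_to_param_devices(optimizer):
    """Hypothetical helper: move each state tensor (e.g. SGD's
    momentum_buffer) onto its parameter's device, in place."""
    for param, state in optimizer.state.items():
        for key, value in list(state.items()):
            # PyTorch's own load_state_dict leaves state["step"] where it is
            # unless the optimizer is fused or capturable; never move it here.
            if key == "step":
                continue
            if torch.is_tensor(value):
                state[key] = value.to(param.device)
```

In the reproducer, the repair call would sit right after `model = model.to("xpu")` and before `ipex.optimize(...)`; how `ipex.optimize` treats the repaired state has not been checked here. A simpler alternative under the same assumption is to call `model.to("xpu")` before `optimizer.load_state_dict(...)`, so PyTorch's own cast places the buffers on the XPU.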

Versions

Collecting environment information...
PyTorch version: 2.1.0.post2+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.30+xpu
IPEX commit: 474a6b3cb
Build type: Release

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Graphics [0x7d55]', platform_name='Intel(R) Level-Zero', dev_type='gpu', driver_version='1.3.27642', has_fp64=1, total_memory=30234MB, max_compute_units=128, gpu_eu_count=128)
Intel OpenCL ICD version: 23.43.27642.40-803~22.04
Level Zero version: 1.3.27642.40-803~22.04

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) Ultra 7 155H
CPU family: 6
Model: 170
Thread(s) per core: 2
Core(s) per socket: 11
Socket(s): 1
Stepping: 4
BogoMIPS: 5990.39
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 528 KiB (11 instances)
L1i cache: 704 KiB (11 instances)
L2 cache: 22 MiB (11 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.30+xpu
[pip3] numpy==1.26.4
[pip3] torch==2.1.0.post2+cxx11.abi
[pip3] torchaudio==2.1.0.post2+cxx11.abi
[pip3] torchvision==0.16.0.post2+cxx11.abi
[conda] N/A

hermanhmchan, Jun 04 '24 12:06