intel-extension-for-pytorch
Loading a model trained on XPU fails when continuing training
Describe the bug
I tried to load a saved checkpoint and continue training (following https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html#load-the-general-checkpoint), but it fails with "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and cpu!"
The following code reproduces the issue.
```python
import torch
import torchvision

############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############

LR = 0.001
DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

PATH = "checkpoint.pth"
#################### Load model start #############################
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#################### Load model end ###############################

model.train()
######################## code changes #######################
model = model.to("xpu")
criterion = criterion.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
######################## code changes #######################

for batch_idx, (data, target) in enumerate(train_loader):
    ########## code changes ##########
    data = data.to("xpu")
    target = target.to("xpu")
    ########## code changes ##########
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)

torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    PATH,
)
print("Execution finished")
```
The output is as follows:

```
Files already downloaded and verified
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 54
52 loss = criterion(output, target)
53 loss.backward()
---> 54 optimizer.step()
55 print(batch_idx)
56 torch.save(
57 {
58 "model_state_dict": model.state_dict(),
(...)
61 PATH,
62 )
File ~/torch/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:521, in sgd_step(self, closure)
518 param2 = get_param2(p, self.params_attr)
519 params2.append(param2)
--> 521 sgd(
522 params_with_grad,
523 params2,
524 d_p_list,
525 momentum_buffer_list,
526 weight_decay=group["weight_decay"],
527 momentum=group["momentum"],
528 lr=group["lr"],
529 dampening=group["dampening"],
530 nesterov=group["nesterov"],
531 maximize=group["maximize"],
532 has_sparse_grad=has_sparse_grad,
533 foreach=group["foreach"],
534 fused=self.fused,
535 )
537 # update momentum_buffers in state
538 for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:464, in sgd(params, params2, d_p_list, momentum_buffer_list, has_sparse_grad, foreach, weight_decay, momentum, lr, dampening, nesterov, maximize, fused)
461 else:
462 func = _single_tensor_sgd
--> 464 func(
465 params,
466 params2,
467 d_p_list,
468 momentum_buffer_list,
469 weight_decay=weight_decay,
470 momentum=momentum,
471 lr=lr,
472 dampening=dampening,
473 nesterov=nesterov,
474 has_sparse_grad=has_sparse_grad,
475 maximize=maximize,
476 fused=fused,
477 )
File ~/torch/lib/python3.10/site-packages/intel_extension_for_pytorch/optim/_functional.py:319, in _single_tensor_sgd(params, params2, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, has_sparse_grad, fused)
317 grad = grads[i] if not maximize else -grads[i]
318 if not grad.is_sparse:
--> 319 momentum_buffer_list[i] = torch.ops.torch_ipex.sgd_fused_step(
320 param,
321 grad,
322 momentum_buffer_list[i],
323 params2[i],
324 momentum,
325 lr,
326 weight_decay,
327 dampening,
328 nesterov,
329 )
330 continue
332 if (
333 param.dtype == torch.bfloat16
334 and grad.is_sparse
(...)
338 ):
339 # packed_add can support sparse tensor
File ~/torch/lib/python3.10/site-packages/torch/_ops.py:692, in OpOverloadPacket.__call__(self, *args, **kwargs)
687 def __call__(self, *args, **kwargs):
688 # overloading __call__ to ensure torch.ops.foo.bar()
689 # is still callable from JIT
690 # We save the function ptr as the `op` attribute on
691 # OpOverloadPacket to access it here.
--> 692 return self._op(*args, **kwargs or {})
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, xpu:0 and cpu!
```
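The stack trace ends in torch.ops.torch_ipex.sgd_fused_step with the momentum buffer as one of its arguments, so my assumption is that the optimizer state restored by optimizer.load_state_dict() stays on the CPU while the model parameters are moved to the XPU. A minimal check along these lines (just a sketch, not part of the repro script) should show the mismatch when run right after the ipex.optimize() call:

```python
# Sketch: print where the parameters and the SGD state tensors (e.g. the
# momentum buffers) live after loading the checkpoint and calling ipex.optimize().
# My expectation (an assumption): parameters report xpu:0, momentum buffers report cpu.
for group in optimizer.param_groups:
    for p in group["params"]:
        state = optimizer.state.get(p, {})
        state_devices = {k: v.device for k, v in state.items() if torch.is_tensor(v)}
        print(tuple(p.shape), p.device, state_devices)
```

If the momentum buffers indeed come back on the CPU, that would match the xpu:0 vs cpu mismatch raised from sgd_fused_step.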
Versions
```
Collecting environment information...
PyTorch version: 2.1.0.post2+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.30+xpu
IPEX commit: 474a6b3cb
Build type: Release

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Graphics [0x7d55]', platform_name='Intel(R) Level-Zero', dev_type='gpu', driver_version='1.3.27642', has_fp64=1, total_memory=30234MB, max_compute_units=128, gpu_eu_count=128)
Intel OpenCL ICD version: 23.43.27642.40-803~22.04
Level Zero version: 1.3.27642.40-803~22.04

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) Ultra 7 155H
CPU family: 6
Model: 170
Thread(s) per core: 2
Core(s) per socket: 11
Socket(s): 1
Stepping: 4
BogoMIPS: 5990.39
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 528 KiB (11 instances)
L1i cache: 704 KiB (11 instances)
L2 cache: 22 MiB (11 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.30+xpu
[pip3] numpy==1.26.4
[pip3] torch==2.1.0.post2+cxx11.abi
[pip3] torchaudio==2.1.0.post2+cxx11.abi
[pip3] torchvision==0.16.0.post2+cxx11.abi
[conda] N/A
```
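For completeness, this is the reordering I would try as a workaround. It is an untested sketch that reuses LR and PATH from the repro script and assumes the root cause is optimizer state being restored on the CPU: move the model to "xpu" before loading the state dicts, then explicitly move any state tensors that are still on the CPU before calling ipex.optimize():

```python
import torch
import torchvision
import intel_extension_for_pytorch as ipex

LR = 0.001
PATH = "checkpoint.pth"

# Move the model to the XPU first, so that the optimizer state is restored
# onto the same device as the parameters when load_state_dict() runs.
model = torchvision.models.resnet50().to("xpu")
criterion = torch.nn.CrossEntropyLoss().to("xpu")
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)

checkpoint = torch.load(PATH, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Belt and braces: move any optimizer state tensors that are still on the CPU.
for state in optimizer.state.values():
    for key, value in state.items():
        if torch.is_tensor(value):
            state[key] = value.to("xpu")

model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer)
```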