`torch.linalg.eigh` is significantly slower than expected on Max Series GPU
Describe the issue
As in #428, I tried `torch.linalg.eigh` on a Max Series GPU using the Intel DevCloud and packages from the `intel` conda channel. The performance on XPU is not much better than on CPU:
>>> import intel_extension_for_pytorch
>>> import torch
>>> intel_extension_for_pytorch.__version__
'2.0.110+xpu'
>>> torch.__version__
'2.0.1a0+cxx11.abi'
>>> X = torch.randn(500, 500)
>>> X_xpu = X.to("xpu")
>>> %time C = X.T @ X
CPU times: user 938 ms, sys: 76.8 ms, total: 1.01 s
Wall time: 115 ms
>>> %time C_xpu = X_xpu.T @ X_xpu
CPU times: user 4.37 ms, sys: 4 µs, total: 4.37 ms
Wall time: 4.21 ms
So GEMM is roughly 27x faster on the XPU device than on the CPU host (4.21 ms vs 115 ms wall time).
However, `torch.linalg.eigh` is barely faster when using the XPU, which is quite unexpected given the speedup for GEMM.
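A single `%time` call folds one-off costs (library loading, and possibly just-in-time kernel compilation on the device) into the measurement, and work submitted to a device queue may still be running when the Python call returns. The sketch below is one way to time such calls more carefully: a few untimed warm-up calls, then the best of several timed calls, with an optional `sync` callable that blocks until queued device work has finished. `time_call` is a hypothetical helper, not part of PyTorch or the Intel extension.

```python
import time
from typing import Callable, Optional


def time_call(
    fn: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 1,
    repeat: int = 5,
) -> float:
    """Return the best wall-clock time, in seconds, of `repeat` timed calls to `fn`.

    `warmup` untimed calls absorb one-off costs. If `sync` is given, it is
    called after every call so that queued asynchronous work is included.
    """
    if repeat < 1:
        raise ValueError("repeat must be at least 1")

    def run_once() -> None:
        fn()
        if sync is not None:
            sync()  # wait for any work the call queued on the device

    for _ in range(warmup):
        run_once()  # untimed warm-up calls
    if sync is not None:
        sync()  # drain pending work before the first timed call

    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        run_once()
        best = min(best, time.perf_counter() - start)
    return best
```

With the extension loaded, usage might look like `time_call(lambda: torch.linalg.eigh(C_xpu), sync=torch.xpu.synchronize)` next to `time_call(lambda: torch.linalg.eigh(C))`, assuming `torch.xpu.synchronize` is available in this build.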
>>> %time _ = torch.linalg.eigh(C)
CPU times: user 2min 30s, sys: 10.2 s, total: 2min 40s
Wall time: 6.89 s
>>> %time _ = torch.linalg.eigh(C_xpu)
CPU times: user 4min 1s, sys: 14.5 s, total: 4min 15s
Wall time: 5.52 s
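The `%time` numbers above can also be compared directly. The following standard-library sketch (the helper names are hypothetical) converts IPython duration strings to seconds and derives two figures from the transcripts: the CPU-to-XPU wall-time speedup, and total CPU time divided by wall time, which is the average number of host cores busy during a call. On these numbers GEMM speeds up by about 27x but `eigh` only by about 1.25x, and the XPU `eigh` call keeps roughly twice as many host cores busy as the CPU one (about 46 vs 23). One possible reading is that much of the XPU `eigh` work runs on the host, but the timings alone do not establish that.

```python
import re

# Multipliers from IPython %time units to seconds.
_UNITS = {"min": 60.0, "s": 1.0, "ms": 1e-3, "µs": 1e-6, "μs": 1e-6, "us": 1e-6, "ns": 1e-9}
# A number followed by a unit, e.g. "2min", "40s", "4.21 ms".
_TOKEN = re.compile(r"(\d+(?:\.\d+)?)\s*(min|ms|µs|μs|us|ns|s)")


def parse_duration(text: str) -> float:
    """Convert one %time duration such as '2min 40s' or '4.21 ms' to seconds."""
    tokens = _TOKEN.findall(text)
    if not tokens:
        raise ValueError(f"no duration found in {text!r}")
    return sum(float(value) * _UNITS[unit] for value, unit in tokens)


# Wall and total CPU times copied from the transcripts above.
MEASUREMENTS = {
    "gemm": {"cpu_wall": "115 ms", "xpu_wall": "4.21 ms"},
    "eigh": {
        "cpu_wall": "6.89 s",
        "xpu_wall": "5.52 s",
        "cpu_total": "2min 40s",
        "xpu_total": "4min 15s",
    },
}


def speedup(op: str) -> float:
    """CPU wall time divided by XPU wall time for operation `op`."""
    m = MEASUREMENTS[op]
    return parse_duration(m["cpu_wall"]) / parse_duration(m["xpu_wall"])


def busy_cores(total: str, wall: str) -> float:
    """Average number of host cores in use: total CPU time / wall time."""
    return parse_duration(total) / parse_duration(wall)


if __name__ == "__main__":
    e = MEASUREMENTS["eigh"]
    print(f"GEMM speedup on XPU: {speedup('gemm'):.1f}x")  # 27.3x
    print(f"eigh speedup on XPU: {speedup('eigh'):.2f}x")  # 1.25x
    print(f"eigh busy host cores (CPU run): {busy_cores(e['cpu_total'], e['cpu_wall']):.1f}")  # 23.2
    print(f"eigh busy host cores (XPU run): {busy_cores(e['xpu_total'], e['xpu_wall']):.1f}")  # 46.2
```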
More information about the runtime environment of this session:
>>> import dpctl
>>> from pprint import pprint
>>> pprint(dpctl.get_devices())
[<dpctl.SyclDevice [backend_type.opencl, device_type.cpu, Intel(R) Xeon(R) Platinum 8480+] at 0x1472aac521f0>,
<dpctl.SyclDevice [backend_type.opencl, device_type.accelerator, Intel(R) FPGA Emulation Device] at 0x1472a80a9ef0>,
<dpctl.SyclDevice [backend_type.level_zero, device_type.gpu, Intel(R) Data Center GPU Max 1100] at 0x1472a80a9df0>]
>>> import joblib
>>> joblib.cpu_count(only_physical_cores=True)
112
>>> import threadpoolctl
>>> pprint(threadpoolctl.threadpool_info())
[{'filepath': '/home/u103854/mambaforge/envs/intel/lib/libmkl_rt.so.2',
'internal_api': 'mkl',
'num_threads': 112,
'prefix': 'libmkl_rt',
'threading_layer': 'intel',
'user_api': 'blas',
'version': '2023.2-Product'},
{'filepath': '/home/u103854/mambaforge/envs/intel/lib/libiomp5.so',
'internal_api': 'openmp',
'num_threads': 112,
'prefix': 'libiomp',
'user_api': 'openmp',
'version': None},
{'filepath': '/home/u103854/mambaforge/envs/intel/lib/libgomp.so.1.0.0',
'internal_api': 'openmp',
'num_threads': 112,
'prefix': 'libgomp',
'user_api': 'openmp',
'version': None}]
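The `threadpool_info()` output shows two different OpenMP runtimes, Intel's `libiomp` and GNU's `libgomp`, loaded in the same process, each reporting 112 threads. A small sketch (hypothetical helper `summarize`, standard library only) that groups such output by `user_api` and flags more than one OpenMP runtime; in a live session it could be fed `threadpoolctl.threadpool_info()` directly. Whether having both runtimes loaded affects the timings above is not established by this report.

```python
from typing import Dict, List


def summarize(info: List[dict]) -> Dict[str, List[str]]:
    """Map each user_api ('blas', 'openmp', ...) to the sorted library prefixes providing it."""
    summary: Dict[str, List[str]] = {}
    for entry in info:
        summary.setdefault(entry["user_api"], []).append(entry["prefix"])
    return {api: sorted(prefixes) for api, prefixes in summary.items()}


# The relevant fields of the threadpool_info() output shown above.
INFO = [
    {"prefix": "libmkl_rt", "user_api": "blas", "internal_api": "mkl", "num_threads": 112},
    {"prefix": "libiomp", "user_api": "openmp", "internal_api": "openmp", "num_threads": 112},
    {"prefix": "libgomp", "user_api": "openmp", "internal_api": "openmp", "num_threads": 112},
]

openmp = summarize(INFO).get("openmp", [])
if len(openmp) > 1:
    # prints: multiple OpenMP runtimes loaded: libgomp, libiomp
    print("multiple OpenMP runtimes loaded:", ", ".join(openmp))
```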
Furthermore, all of these timings are extremely slow for such a small (500 x 500) problem.
Here is the output of a similar experiment on my local laptop (Apple M1):
>>> import torch
>>> X = torch.randn(500, 500)
>>> %time C = X.T @ X
CPU times: user 247 µs, sys: 718 µs, total: 965 µs
Wall time: 4.5 ms
>>> %time _ = torch.linalg.eigh(C)
CPU times: user 12.3 ms, sys: 6.88 ms, total: 19.2 ms
Wall time: 20.6 ms
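Setting the laptop wall times against the Xeon and Max 1100 ones from above gives a rough sense of the gap. This is a back-of-the-envelope comparison between very different machines, not a controlled benchmark; `WALL_TIMES` and `ratio` are illustrative names. On these numbers `eigh` takes about 330x longer on the Xeon host and about 270x longer on the XPU than on the M1, while the GEMM wall times on the XPU and the M1 are within about 10% of each other.

```python
# Wall times in seconds, copied from the transcripts above.
WALL_TIMES = {
    "gemm": {"xeon_cpu": 0.115, "max1100_xpu": 0.00421, "m1_cpu": 0.0045},
    "eigh": {"xeon_cpu": 6.89, "max1100_xpu": 5.52, "m1_cpu": 0.0206},
}


def ratio(op: str, slow: str, fast: str) -> float:
    """How many times longer `slow` took than `fast` for operation `op`."""
    times = WALL_TIMES[op]
    return times[slow] / times[fast]


for op in WALL_TIMES:
    print(
        f"{op}: Xeon CPU / M1 = {ratio(op, 'xeon_cpu', 'm1_cpu'):.1f}x, "
        f"Max 1100 XPU / M1 = {ratio(op, 'max1100_xpu', 'm1_cpu'):.1f}x"
    )
```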
Here is the output of `mamba list` for this environment:
# packages in environment at /home/u103854/mambaforge/envs/intel:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge intel
_openmp_mutex 4.5 2_gnu intel
asttokens 2.2.1 pyhd8ed1ab_0 conda-forge
backcall 0.2.0 pyh9f0ad1d_0 conda-forge
backports 1.0 pyhd8ed1ab_3 conda-forge
backports.functools_lru_cache 1.6.5 pyhd8ed1ab_0 conda-forge
brotli 1.0.9 h166bdaf_8 intel
brotli-bin 1.0.9 h166bdaf_8 intel
bzip2 1.0.8 hb9a14ef_9 intel
ca-certificates 2023.7.22 hbcca054_0 conda-forge
certifi 2023.7.22 pyhd8ed1ab_0 conda-forge
charset-normalizer 3.1.0 pyhd8ed1ab_0 intel
daal4py 2023.2.1 py310_intel_32 intel
dal 2023.2.1 intel_32 intel
decorator 5.0.9 pyhd3eb1b0_0 intel
dpcpp-cpp-rt 2023.2.0 intel_49495 intel
dpcpp_cpp_rt 2023.2.0 intel_49495 intel
dpctl 0.14.5 py310he78b74f_24 intel
dpnp 0.12.1 pypi_0 pypi
executing 1.2.0 pyhd8ed1ab_0 conda-forge
filelock 3.6.0 pyhd3eb1b0_0 intel
fortran_rt 2023.2.0 intel_49495 intel
icc_rt 2023.2.0 intel_49495 intel
idna 3.4 pyhd8ed1ab_0 intel
impi_rt 2021.10.0 intel_49371 intel
intel-cmplr-lib-rt 2023.2.0 intel_49495 intel
intel-cmplr-lic-rt 2023.2.0 intel_49495 intel
intel-extension-for-pytorch 2.0.110 py310_xpu_0 intel
intel-fortran-rt 2023.2.0 intel_49495 intel
intel-opencl-rt 2023.2.0 intel_49495 intel
intel-openmp 2023.2.0 intel_49495 intel
intelpython 2023.2.0 0 intel
ipython 8.14.0 pyh41d4057_0 conda-forge
jedi 0.19.0 pyhd8ed1ab_0 conda-forge
jinja2 3.0.1 pyhd3eb1b0_0 intel
joblib 1.2.0 pyh3f38642_0 intel
lark-parser 0.9.0 pyh9f0ad1d_0 intel
level-zero 1.11.0 h00ab1b0_0 intel
libbrotlicommon 1.0.9 h166bdaf_8 intel
libbrotlidec 1.0.9 h166bdaf_8 intel
libbrotlienc 1.0.9 h166bdaf_8 intel
libffi 3.4.2 h7f98852_5 intel
libgcc-ng 12.2.0 h65d4601_19 intel
libgomp 12.2.0 h65d4601_19 intel
libnsl 2.0.0 h7f98852_0 intel
libsqlite 3.42.0 h2797004_0 intel
libstdcxx-ng 12.2.0 h46fd767_19 intel
libuuid 2.38.1 h0b41bf4_0 intel
libuv 1.40.0 h7b6447c_2 intel
libzlib 1.2.13 hd590300_5 intel
markupsafe 2.1.3 py310h2372a71_0 conda-forge
matplotlib-inline 0.1.6 pyhd8ed1ab_0 conda-forge
mkl 2023.2.0 intel_49495 intel
mkl-dpcpp 2023.2.0 intel_49495 intel
mkl-service 2.4.0 py310hae59892_35 intel
mkl_fft 1.3.6 py310h173b8ae_56 intel
mkl_random 1.2.2 py310h1595b48_76 intel
mkl_umath 0.1.1 py310hd987cd3_86 intel
mpi4py 3.1.4 py310h618b5fa_0 intel
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
ncurses 6.4 hcb278e6_0 intel
networkx 2.6.2 pyhd3eb1b0_2 intel
numpy 1.24.3 py310hed7eef7_0 intel
numpy-base 1.24.3 py310he88ecf9_0 intel
openssl 3.1.3 hd590300_0 conda-forge
packaging 23.1 pyhd8ed1ab_0 intel
parso 0.8.3 pyhd8ed1ab_0 conda-forge
pexpect 4.8.0 pyh1a96a4e_2 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pip 23.1.2 pyhd8ed1ab_0 intel
platformdirs 3.6.0 pyhd8ed1ab_0 intel
pooch 1.7.0 pyha770c72_3 intel
prompt-toolkit 3.0.39 pyha770c72_0 conda-forge
prompt_toolkit 3.0.39 hd8ed1ab_0 conda-forge
psutil 5.9.5 py310h1fa729e_0 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge
pygments 2.16.1 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 intel
python 3.10.12 hef7c979_1 intel
python_abi 3.10 2_cp310 intel
pytorch 2.0.1 py310_xpu_0 intel
readline 8.2 h8228510_1 intel
requests 2.31.0 pyhd8ed1ab_0 intel
scikit-learn 1.2.2 py310hf7d194e_2 intel
scikit-learn-intelex 2023.2.1 py310_intel_32 intel
scipy 1.10.1 py310h01e2e1b_0 intel
setuptools 67.7.2 pyhd8ed1ab_0 intel
six 1.16.0 pyhd3eb1b0_1 intel
stack_data 0.6.2 pyhd8ed1ab_0 conda-forge
sympy 1.12 pyh04b8f61_3 conda-forge
tbb 2021.10.0 intel_49541 intel
tbb4py 2021.10.0 py310_intel_49541 intel
threadpoolctl 3.1.0 pyh8a188c0_0 intel
tk 8.6.12 h1ccaba5_0 intel
traitlets 5.9.0 pyhd8ed1ab_0 conda-forge
typing-extensions 4.6.3 hd8ed1ab_0 intel
typing_extensions 4.6.3 pyha770c72_0 intel
tzdata 2023c h71feb2d_0 intel
urllib3 2.0.3 pyhd8ed1ab_0 intel
wcwidth 0.2.6 pyhd8ed1ab_0 conda-forge
wheel 0.40.0 pyhd8ed1ab_0 intel
xz 5.2.8 h5eee18b_0 intel
zlib 1.2.13 hd590300_5 intel
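To pull the versions that matter most for this report (PyTorch, the Intel extension, and the MKL packages) out of a listing like the one above, a standard-library sketch such as the following could be used. It assumes the usual `conda list` / `mamba list` layout of name, version, build and an optional channel column, with `#` comment lines; `parse_conda_list` and `Package` are hypothetical names, and only a few rows of the listing are reproduced.

```python
from typing import Dict, NamedTuple


class Package(NamedTuple):
    version: str
    build: str
    channel: str


def parse_conda_list(text: str) -> Dict[str, Package]:
    """Parse `conda list` / `mamba list` text output into {name: Package}."""
    packages: Dict[str, Package] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and the header/comment lines
        fields = line.split()
        if len(fields) < 3:
            continue  # not a package row
        name, version, build = fields[:3]
        channel = fields[3] if len(fields) > 3 else ""  # channel column may be empty
        packages[name] = Package(version, build, channel)
    return packages


SAMPLE = """\
# Name Version Build Channel
intel-extension-for-pytorch 2.0.110 py310_xpu_0 intel
mkl 2023.2.0 intel_49495 intel
mkl-dpcpp 2023.2.0 intel_49495 intel
pytorch 2.0.1 py310_xpu_0 intel
"""

pkgs = parse_conda_list(SAMPLE)
for name in ("pytorch", "intel-extension-for-pytorch", "mkl", "mkl-dpcpp"):
    print(name, pkgs[name].version, pkgs[name].channel)
```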
Thank you for reporting this; we are investigating the issue.
@gujinghui @tye1