intel-extension-for-tensorflow
DLPack conversion does not work
System: WSL2 Ubuntu 22.04, on top of Windows 11
CPU: 1270P
GPU: integrated ([ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Graphics [0x46a6] 1.3 [1.3.26032])
Tensorflow: 2.12.0
Jax: 0.4.4
```python
import jax.numpy as jnp
jnp.arange(10, dtype=jnp.float32).__dlpack__()
```
This results in the following error:
Python 3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> jnp.arange(10, dtype=jnp.float32).__dlpack__()
2023-07-06 21:56:46.726166: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x558bc41f0c80 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-07-06 21:56:46.726195: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): Interpreter, <undefined>
2023-07-06 21:56:46.732904: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:215] TfrtCpuClient created.
2023-07-06 21:56:46.733628: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: libtpu.so
2023-07-06 21:56:46.733881: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
2023-07-06 21:56:46.993558: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:85] GetPjrtApi was found for xpu at /home/yevhenii/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so
2023-07-06 21:56:46.993610: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type xpu
2023-07-06 21:56:46.994083: I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
2023-07-06 21:56:46.994362: I itex/core/devices/gpu/itex_gpu_runtime.cc:154] number of sub-devices is zero, expose root device.
2023-07-06 21:56:46.998761: I itex/core/compiler/xla/service/service.cc:176] XLA service 0x558bc64650a0 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
2023-07-06 21:56:46.998790: I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (0): <undefined>, <undefined>
2023-07-06 21:56:47.000500: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc:83] PjRtCApiClient created.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/jax/_src/array.py", line 343, in __dlpack__
return to_dlpack(self)
File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/jax/_src/dlpack.py", line 51, in to_dlpack
return xla_client._xla.buffer_to_dlpack_managed_tensor(
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference
JAX itself works on the Level-Zero GPU, so the environment is not broken. I suspect this is a missing implementation behind the PJRT C API error: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference. It blocks user workflows that require both JAX operations and, for example, numba_dpex operations without copying memory.
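For reference, the zero-copy exchange that `__dlpack__` is supposed to enable can be demonstrated with NumPy alone on host memory (a minimal sketch; the failing path above is the same protocol on the XPU device):

```python
import numpy as np

# __dlpack__ exports a PyCapsule describing the buffer; np.from_dlpack
# consumes it and builds a new array over the same memory, without copying.
src = np.arange(10, dtype=np.float32)
dst = np.from_dlpack(src)

# No copy took place: both arrays alias the same buffer.
print(np.shares_memory(src, dst))  # True
```

(`np.from_dlpack` requires NumPy >= 1.22.) This host-side round trip is what workflows mixing JAX and numba_dpex need on device memory, which is exactly what the unimplemented AcquireExternalReference blocks.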
Hi @ZzEeKkAa,
Thank you for trying the feature.
We can reproduce this on Linux.
Yes, you're right: there is an unimplemented interface. jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference
[stream=0x55e8dc9039e0,impl=0x55e8dd530ff0] StreamPool returning ok stream
2023-07-07 14:45:37.326709: I itex/core/compiler/xla/service/gpu/gpu_executable.cc:477] GpuExecutable::ExecuteAsyncOnStreamImpl(jit_iota) time: 129 ms (cumulative: 129 ms, max: 129 ms, #called: 1)
2023-07-07 14:45:37.326719: I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1989] Replica 0 partition 0 completed; ok=1
2023-07-07 14:45:37.326884: I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2268] Replicated execution complete.
Traceback (most recent call last):
File "/ws1/xigui/itex-source-code/dlpack.py", line 2, in <module>
jnp.arange(10, dtype=jnp.float32).__dlpack__()
File "/home/xiguiwang/ws1/conda/itex-build/lib/python3.9/site-packages/jax/_src/array.py", line 343, in __dlpack__
return to_dlpack(self)
File "/home/xiguiwang/ws1/conda/itex-build/lib/python3.9/site-packages/jax/_src/dlpack.py", line 51, in to_dlpack
return xla_client._xla.buffer_to_dlpack_managed_tensor(
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference
2023-07-07 14:45:37.338990: I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1208] PjRtStreamExecutorBuffer::Delete
2023-07-07 14:45:37.339017: I itex/core/compiler/xla/stream_executor/stream_executor_pimpl.cc:217] Called StreamExecutor::Deallocate(mem=0xffff81ac00000000) mem->size()=40
2023-07-07 14:45:37.350224: I itex/core/compiler/xla/stream_executor/stream_executor_pimpl.cc:367] Called StreamExecutor::SynchronizeAllActivity()
@yiqianglee @jzhoulon
Do you have any idea about this "jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference"?
Is this a limitation of the PJRT C API interface design (it does not cover/support this), or a missing implementation in ITEX?
Xigui
Yes, this is a limitation of the PJRT C API interface design, and we need support from the OpenXLA side: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_c_api_client.h#L281-L285
@ZzEeKkAa Hope this helps for you.
This works with recent Intel Extension for TensorFlow 2.15 + oneAPI 2024.1.
These are the software versions I used to verify this:
intel-extension-for-tensorflow 2.15.0.0
intel-extension-for-tensorflow-lib 2.15.0.0.2
jax 0.4.26
jaxlib 0.4.26
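To check which of these versions are present in a given environment, a standard-library snippet can be used (package names taken from the list above; anything not installed is reported as such):

```python
import importlib.metadata as md

# Report the versions of the packages relevant to this issue, if installed.
for pkg in ("intel-extension-for-tensorflow", "jax", "jaxlib"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "not installed")
```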
Here is the documentation for setting up intel-extension-for-tensorflow: https://intel.github.io/intel-extension-for-tensorflow/v2.15.0.0/docs/install/install_for_xpu.html#intel-xpu-software-installation