jax-windows-builder
Error with CUDNN 8.9.2
Hi,
I just installed [cuda/jaxlib-0.4.11+cuda.cudnn89-cp311-cp311-win_amd64.whl](https://whls.blob.core.windows.net/unstable/cuda/jaxlib-0.4.11+cuda.cudnn89-cp311-cp311-win_amd64.whl). I have CUDA 11.7 and CUDNN 8.9.2. When I run this:
import jax
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
jax.numpy.array(1.0)
I get the correct print output but then an error:
gpu
2023-07-03 13:06:41.867294: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
Traceback (most recent call last):
File "S:\dev\tapnet\projects\test.py", line 10, in <module>
jax.numpy.array(1.0)
File "S:\dev\tapnet\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2051, in array
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\lax\lax.py", line 549, in _convert_element_type
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\util.py", line 277, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\interpreters\pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\interpreters\pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\interpreters\pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Any idea where I should start looking? Thanks!
These wheels are not released; they are a side effect of the overhaul branch. Please don't use them at the moment.
This build was made against CUDA 12.1 + cuDNN 8.9:
https://github.com/cloudhan/jax-windows-builder/blob/c53944319dd133592dfdd6593a5f73321fc45e83/build-jaxlib.ps1#L35-L42
Thanks for your reply. So 12.1 + 8.9.1 and 11.8 + 8.6.0 would both be options?
Yes.
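For anyone who wants to check a local setup before installing, here is a minimal sketch in Python. It only encodes the two combinations confirmed in this thread (CUDA 12.1 + cuDNN 8.9 and CUDA 11.8 + cuDNN 8.6). The names `SUPPORTED`, `major_minor`, and `is_supported` are hypothetical helpers, not part of jax or jax-windows-builder, and comparing only major.minor is an assumption rather than an official support matrix.

```python
# Combinations confirmed in this thread; NOT an official support matrix.
SUPPORTED = {("12.1", "8.9"), ("11.8", "8.6")}


def major_minor(version: str) -> str:
    """Reduce a version string such as '8.9.2' to its 'major.minor' part."""
    parts = version.strip().split(".")
    return ".".join(parts[:2])


def is_supported(cuda_version: str, cudnn_version: str) -> bool:
    """Return True if the CUDA/cuDNN pair matches a combination from the thread.

    Assumption: patch-level differences (e.g. 8.9.1 vs 8.9.2) are ignored.
    """
    return (major_minor(cuda_version), major_minor(cudnn_version)) in SUPPORTED


if __name__ == "__main__":
    # The setup from the original report: CUDA 11.7 with cuDNN 8.9.2.
    print(is_supported("11.7", "8.9.2"))
    # One of the combinations confirmed above.
    print(is_supported("12.1", "8.9.1"))
```

With the setup from the original report (CUDA 11.7 and cuDNN 8.9.2) the check fails because CUDA 11.7 is in neither pair, which is consistent with the mismatch described above.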
Thank you!