jax-windows-builder

Error with CUDNN 8.9.2

Open · drperpen opened this issue 1 year ago · 5 comments

Hi, I just installed [cuda/jaxlib-0.4.11+cuda.cudnn89-cp311-cp311-win_amd64.whl](https://whls.blob.core.windows.net/unstable/cuda/jaxlib-0.4.11+cuda.cudnn89-cp311-cp311-win_amd64.whl). I have CUDA 11.7 and cuDNN 8.9.2. When I run this:

```python
import jax
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
jax.numpy.array(1.0)
```

I get the correct print output but then an error:

```text
gpu
2023-07-03 13:06:41.867294: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
Traceback (most recent call last):
  File "S:\dev\tapnet\projects\test.py", line 10, in <module>
    jax.numpy.array(1.0)
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2051, in array
    out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\lax\lax.py", line 549, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 132, in apply_primitive
    compiled_fun = xla_primitive_callable(
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\util.py", line 277, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 223, in xla_primitive_callable
    compiled = _xla_callable_uncached(
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 253, in _xla_callable_uncached
    return computation.compile().unsafe_call
           ^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\interpreters\pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\interpreters\pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\interpreters\pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "S:\dev\tapnet\Lib\site-packages\jax\_src\dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
```

Any idea where I should start looking? Thanks!

drperpen · Jul 03 '23 04:07
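Since the report pairs CUDA 11.7 with cuDNN 8.9.2, one way to check which CUDA runtime and cuDNN libraries Python can actually reach on `PATH` is to query them directly with `ctypes`. This is a best-effort sketch, not part of jax or jaxlib: the helper names are hypothetical, the DLL names `cudart64_110.dll` and `cudnn64_8.dll` are assumptions for a CUDA 11.x + cuDNN 8.x install, and the cuDNN decoding only applies to the 8.x version scheme (`MAJOR*1000 + MINOR*100 + PATCH`).

```python
import ctypes
import ctypes.util
import sys


def decode_cudnn_version(v: int) -> tuple:
    """Decode a cuDNN 8.x version integer (MAJOR*1000 + MINOR*100 + PATCH)."""
    return v // 1000, (v % 1000) // 100, v % 100


def decode_cuda_runtime_version(v: int) -> tuple:
    """Decode cudaRuntimeGetVersion output (1000*MAJOR + 10*MINOR)."""
    return v // 1000, (v % 1000) // 10


def load_from_path(dll_name: str) -> ctypes.CDLL:
    """Load a DLL by searching the directories listed in PATH (hypothetical helper)."""
    path = ctypes.util.find_library(dll_name)
    if path is None:
        raise OSError(f"{dll_name} not found on PATH")
    # On 64-bit Windows, CDLL and WinDLL share one calling convention.
    return ctypes.CDLL(path)


def query_versions(cudart_dll: str, cudnn_dll: str) -> dict:
    """Best-effort report of the CUDA runtime and cuDNN versions found on PATH."""
    result = {}
    try:
        cudart = load_from_path(cudart_dll)
        out = ctypes.c_int(0)
        status = cudart.cudaRuntimeGetVersion(ctypes.byref(out))  # 0 == cudaSuccess
        result["cuda_runtime"] = (
            decode_cuda_runtime_version(out.value) if status == 0 else f"error {status}"
        )
    except (OSError, AttributeError) as e:
        result["cuda_runtime"] = f"unavailable: {e}"
    try:
        cudnn = load_from_path(cudnn_dll)
        cudnn.cudnnGetVersion.restype = ctypes.c_size_t
        result["cudnn"] = decode_cudnn_version(cudnn.cudnnGetVersion())
    except (OSError, AttributeError) as e:
        result["cudnn"] = f"unavailable: {e}"
    return result


if __name__ == "__main__" and sys.platform == "win32":
    # Hypothetical DLL names for a CUDA 11.x + cuDNN 8.x install; adjust to your setup.
    print(query_versions("cudart64_110.dll", "cudnn64_8.dll"))
```

Comparing the printed versions against the combination a wheel was built with is a quick sanity check before digging into the traceback itself.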

These wheels are not released; they are a side effect of the overhaul branch. Please don't use them for now.

cloudhan · Jul 03 '23 04:07

The build combination is CUDA 12.1 + cuDNN 8.9.

https://github.com/cloudhan/jax-windows-builder/blob/c53944319dd133592dfdd6593a5f73321fc45e83/build-jaxlib.ps1#L35-L42

cloudhan · Jul 03 '23 04:07

Thanks for your reply. So CUDA 12.1 + cuDNN 8.9.1 and CUDA 11.8 + cuDNN 8.6.0 would both be options?

drperpen · Jul 03 '23 04:07

Yes.

cloudhan · Jul 03 '23 04:07
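The answer above can be condensed into a small pre-flight check, for example to fail fast before importing jax. A minimal sketch: `KNOWN_COMBOS` and `check_combo` are hypothetical names, the table holds only the two pairs named in this thread (the thread does not spell out which wheel goes with the CUDA 11.8 pair), and requiring an exact match on CUDA major.minor and cuDNN major.minor is a conservative assumption rather than an official compatibility rule.

```python
# CUDA (major, minor) -> cuDNN (major, minor), as named in this thread.
# Exact matching on major.minor is a conservative assumption, not an
# official compatibility rule.
KNOWN_COMBOS = {
    (12, 1): (8, 9),  # build combination of the unstable cudnn89 wheels
    (11, 8): (8, 6),
}


def check_combo(cuda, cudnn):
    """Return None if (cuda, cudnn) matches a listed combination, else a hint string."""
    cuda_mm = tuple(cuda[:2])
    expected = KNOWN_COMBOS.get(cuda_mm)
    if expected is None:
        listed = ", ".join(f"{a}.{b}" for a, b in KNOWN_COMBOS)
        return f"CUDA {cuda_mm[0]}.{cuda_mm[1]} is not a listed combination (listed: {listed})"
    if tuple(cudnn[:2]) != expected:
        found = ".".join(str(x) for x in cudnn)
        return (
            f"CUDA {cuda_mm[0]}.{cuda_mm[1]} pairs with cuDNN "
            f"{expected[0]}.{expected[1]}.x, found {found}"
        )
    return None


# The setup from the original report: CUDA 11.7 with cuDNN 8.9.2.
print(check_combo((11, 7), (8, 9, 2)))
# -> CUDA 11.7 is not a listed combination (listed: 12.1, 11.8)
```

Fed with the versions from the original report, it flags CUDA 11.7 as unlisted, while both pairs confirmed in the thread pass.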

Thank you!

drperpen · Jul 03 '23 05:07