JAX is not finding cuDNN even though it is on LD_LIBRARY_PATH
Description
I have a project that is relatively simple at the moment, yet whenever a convolution is called, the following error is raised:
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: DNN library is not found.
I tried to create a minimal reproducible example, but for some reason it produces a different error message. Here is the example:
from jax.lax import conv
import jax.random as jrd
rng = jrd.PRNGKey(0)
x = jrd.truncated_normal(rng, -1, 1, (1, 1, 28, 28))
kernel = jrd.truncated_normal(rng, -1, 1, (2, 1, 3, 3))
y = conv(x, kernel, window_strides=(1, 1), padding='SAME')
Here is the full output when this script is run:
2022-09-22 13:34:04.032810: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:727] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms. Conv: (f32[1,2,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0}, f32[2,1,3,3]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
2022-09-22 13:34:04.112949: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:930] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,2,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %Arg_0.1, f32[2,1,3,3]{3,2,1,0} %Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 28, 28) rhs_shape=(2, 1, 3, 3) precision=None preferred_element_type=None]" source_file="/home/medusa/Projects/repros/jaxconv.py" source_line=5}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
Original error: INTERNAL: All algorithms tried for %cudnn-conv = (f32[1,2,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %Arg_0.1, f32[2,1,3,3]{3,2,1,0} %Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 28, 28) rhs_shape=(2, 1, 3, 3) precision=None preferred_element_type=None]" source_file="/home/medusa/Projects/repros/jaxconv.py" source_line=5}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. Per-algorithm errors:
Profiling failure on cuDNN engine 1#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
[the same "Profiling failure ... UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward(...)'" message is repeated for cuDNN engines 0, 1, 2, 4, 5, 6, and 7, each with and without #TC, over two autotuning passes]
As a result, convolution performance may be suboptimal.
2022-09-22 13:34:04.118887: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Traceback (most recent call last):
File "/home/medusa/Projects/repros/jaxconv.py", line 7, in <module>
y = conv(x, kernel, window_strides=(1, 1), padding='SAME')
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/lax/convolution.py", line 194, in conv
return conv_general_dilated(lhs, rhs, window_strides, padding,
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/lax/convolution.py", line 158, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/core.py", line 325, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/core.py", line 328, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/core.py", line 686, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/dispatch.py", line 113, in apply_primitive
return compiled_fun(*args)
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/dispatch.py", line 198, in <lambda>
return lambda *args, **kw: compiled(*args, **kw)[0]
File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/dispatch.py", line 837, in _execute_compiled
out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
It's a remarkably verbose error for a convolution that should be exceedingly simple. The failure occurs regardless of whether export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false" is set in my .bashrc or commented out.
From other GitHub issues I have read, the usual cause of this error is that cuDNN is simply not installed. However, I followed NVIDIA's installation instructions, and it is indeed present:
$ echo $LD_LIBRARY_PATH
:/usr/local/cuda-11.6/include:/usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/include:/usr/local/cuda-11.6/lib64
Indeed, inside /usr/local/cuda-11.6/lib64 are libcudnn.so and libcudnn.so.8, along with a number of other cuDNN shared objects. It seems that either JAX is looking for something else or I am missing some important file. I'm guessing it's the latter, so a list of files I should double-check for would be extremely helpful.
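For reference, a quick way to double-check that the dynamic loader can actually find and open cuDNN at runtime (a minimal sketch; it assumes only the standard libcudnn.so.8 soname):
# Sketch: verify that the dynamic loader can locate and open cuDNN.
import ctypes
try:
    ctypes.CDLL("libcudnn.so.8")
    print("libcudnn.so.8 loaded successfully")
except OSError as err:
    print(f"could not load libcudnn.so.8: {err}")
This only tests that the library is discoverable, not that cuDNN can actually initialize on the GPU, which turns out to matter below.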
What jax/jaxlib version are you using?
jax v0.3.17, jaxlib v0.3.15
Which accelerator(s) are you using?
GPU
Additional System Info
Python 3.10.4 / Ubuntu 21.10
What GPU do you have, and what is the output of nvidia-smi?
GeForce RTX 2060
$ nvidia-smi
Fri Sep 23 15:54:49 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:01:00.0 On | N/A |
| N/A 41C P8 11W / N/A | 517MiB / 6144MiB | 20% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2578 G /usr/lib/xorg/Xorg 302MiB |
| 0 N/A N/A 2712 G /usr/bin/gnome-shell 52MiB |
| 0 N/A N/A 3707 G ...0/usr/lib/firefox/firefox 152MiB |
| 0 N/A N/A 36453 G ...RendererForSitePerProcess 8MiB |
+-----------------------------------------------------------------------------+
By the way, no such error occurs with regular matrix multiplication, and the following script noticeably increases GPU utilization:
import jax.numpy as jnp
import jax.random as jrd
rng = jrd.PRNGKey(0)
x = jrd.truncated_normal(rng, -1, 1, (100000,))
kernel = jrd.truncated_normal(rng, -1, 1, (1000, 100000))
y = jnp.matmul(kernel, x)
Try an experiment for me: disable preallocation, lower the memory fraction, or use the platform allocator, as described here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
Does that help?
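Concretely, the options on that page come down to environment variables that must be set before JAX initializes the GPU backend, for example (a sketch; pick one of the three):
# Sketch: memory-allocation knobs from the linked page; set them before JAX touches the GPU.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"     # (1) disable preallocation
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"    # (2) preallocate a smaller fraction
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # (3) allocate on demand (slower)
import jax  # import JAX only after the variables are set
Setting them as shell exports before running the script works just as well.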
@hawkinsp Phenomenal! Disabling preallocation solves the issue!
I was aware that these flags exist for when too much memory gets preallocated while Xorg (etc.) is active in the background, but I would never have made the connection to being unable to load cuDNN. What is the root cause here, and how did you know to try that?
The underlying problem is that cuDNN will fail to initialize if JAX grabs too much GPU memory.
There are two things we can do to improve matters here: a) improve the error message to suggest this as a possible cause, and b) change the JAX allocator design a bit.
(b) is longer-term work, but (a) we should be able to do sooner.
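For a rough sense of the numbers on the card above (a back-of-the-envelope sketch; the 90% preallocation fraction is the default described in the linked docs and is treated here as an assumption, not a measurement):
# Sketch: why a 6 GiB desktop GPU can run out of room for cuDNN.
total_mib = 6144          # RTX 2060, per nvidia-smi above
desktop_mib = 517         # Xorg / gnome-shell / firefox, per nvidia-smi above
prealloc_fraction = 0.90  # assumed default fraction of currently-available memory
available = total_mib - desktop_mib
headroom = available * (1 - prealloc_fraction)
print(f"left over after preallocation: {headroom:.0f} MiB")  # roughly 560 MiB
# That headroom has to cover cuDNN's own initialization and conv workspaces,
# which can exceed what is left, hence CUDNN_STATUS_EXECUTION_FAILED.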