JAX is not finding cuDNN even though it is on LD_LIBRARY_PATH

Open Ebanflo42 opened this issue 3 years ago • 5 comments

Description

I have a project which is relatively simple at the moment, yet whenever a convolution is called the following error is raised:

jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: DNN library is not found.

I tried to create a minimal reproducible example, but the error message is different for some reason. Here is the example:

from jax.lax import conv
import jax.random as jrd

rng = jrd.PRNGKey(0)
x = jrd.truncated_normal(rng, -1, 1, (1, 1, 28, 28))
kernel = jrd.truncated_normal(rng, -1, 1, (2, 1, 3, 3))
y = conv(x, kernel, window_strides=(1, 1), padding='SAME')

Here is the full output when this script is run:

2022-09-22 13:34:04.032810: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:727] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.  Conv: (f32[1,2,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0}, f32[2,1,3,3]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
2022-09-22 13:34:04.112949: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:930] Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,2,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %Arg_0.1, f32[2,1,3,3]{3,2,1,0} %Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 28, 28) rhs_shape=(2, 1, 3, 3) precision=None preferred_element_type=None]" source_file="/home/medusa/Projects/repros/jaxconv.py" source_line=5}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: INTERNAL: All algorithms tried for %cudnn-conv = (f32[1,2,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %Arg_0.1, f32[2,1,3,3]{3,2,1,0} %Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 28, 28) rhs_shape=(2, 1, 3, 3) precision=None preferred_element_type=None]" source_file="/home/medusa/Projects/repros/jaxconv.py" source_line=5}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine 1#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 1: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 0#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 0: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 2#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 2: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 4#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 4: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 6#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 6: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 5#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 5: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 7#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 7: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 1#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 1: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 0#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 0: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 2#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 2: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 4#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 4: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 6#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 6: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 5#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 5: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 7#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
  Profiling failure on cuDNN engine 7: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'

As a result, convolution performance may be suboptimal.
2022-09-22 13:34:04.118887: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Traceback (most recent call last):
  File "/home/medusa/Projects/repros/jaxconv.py", line 7, in <module>
    y = conv(x, kernel, window_strides=(1, 1), padding='SAME')
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/lax/convolution.py", line 194, in conv
    return conv_general_dilated(lhs, rhs, window_strides, padding,
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/lax/convolution.py", line 158, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/core.py", line 325, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/core.py", line 328, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/core.py", line 686, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/dispatch.py", line 113, in apply_primitive
    return compiled_fun(*args)
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/dispatch.py", line 198, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/medusa/anaconda3/envs/critrep/lib/python3.10/site-packages/jax/_src/dispatch.py", line 837, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4023): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'

It's a really strange and verbose error for a convolution that should be exceedingly simple. It occurs whether or not the line export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false" is commented out in my .bashrc.

Based on my reading of other GitHub issues, the usual cause of this is that cuDNN is simply not installed. However, I followed NVIDIA's installation instructions, and it is indeed present:

$ echo $LD_LIBRARY_PATH
:/usr/local/cuda-11.6/include:/usr/local/cuda-11.6/lib64:/usr/local/cuda-11.6/include:/usr/local/cuda-11.6/lib64

Indeed, /usr/local/cuda-11.6/lib64 contains libcudnn.so and libcudnn.so.8, as well as many other cuDNN shared object files. Either there is a problem with what JAX is looking for, or I am missing some important file. I'm guessing it's the latter, so a list of files I should double-check for would be extremely helpful.
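One way to double-check that the dynamic loader can actually resolve cuDNN (rather than only confirming the files exist on disk) is to try loading it from Python. This is a minimal sketch, not something from the JAX docs: the helper name cudnn_version is hypothetical, and the library name libcudnn.so.8 is taken from the directory listing above. cudnnGetVersion is part of the cuDNN C API.

```python
import ctypes


def cudnn_version(libname="libcudnn.so.8"):
    """Return the cuDNN version integer, or None if the loader cannot find it.

    Hypothetical helper for illustration; libname is an assumption based on
    the files listed in /usr/local/cuda-11.6/lib64.
    """
    try:
        # CDLL searches LD_LIBRARY_PATH and the system loader cache.
        lib = ctypes.CDLL(libname)
    except OSError:
        return None
    # cudnnGetVersion takes no arguments and returns a size_t.
    lib.cudnnGetVersion.restype = ctypes.c_size_t
    return int(lib.cudnnGetVersion())


version = cudnn_version()
if version is None:
    print("cuDNN could not be loaded by the dynamic loader")
else:
    print("cuDNN loaded, version:", version)
```

If this loads successfully, the library is visible to the process, which suggests the problem lies somewhere other than the library search path.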

What jax/jaxlib version are you using?

jax v0.3.17, jaxlib v0.3.15

Which accelerator(s) are you using?

GPU

Additional System Info

Python 3.10.4 / Ubuntu 21.10

Ebanflo42, Sep 22 '22 11:09

What GPU do you have, and what is the output of nvidia-smi?

froystig, Sep 22 '22 22:09

GeForce RTX 2060

$ nvidia-smi
Fri Sep 23 15:54:49 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0  On |                  N/A |
| N/A   41C    P8    11W /  N/A |    517MiB /  6144MiB |     20%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2578      G   /usr/lib/xorg/Xorg                302MiB |
|    0   N/A  N/A      2712      G   /usr/bin/gnome-shell               52MiB |
|    0   N/A  N/A      3707      G   ...0/usr/lib/firefox/firefox      152MiB |
|    0   N/A  N/A     36453      G   ...RendererForSitePerProcess        8MiB |
+-----------------------------------------------------------------------------+

By the way, no such error occurs with regular matrix multiplication, and the following script noticeably increases GPU utilization:

import jax.numpy as jnp
import jax.random as jrd

rng = jrd.PRNGKey(0)
x = jrd.truncated_normal(rng, -1, 1, (100000,))
kernel = jrd.truncated_normal(rng, -1, 1, (1000, 100000))
y = jnp.matmul(kernel, x)

Ebanflo42, Sep 23 '22 14:09

Try an experiment for me: try disabling preallocation, lowering the memory fraction, or using the platform allocator as described here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

Does that help?
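For reference, the three options on that page are controlled by environment variables, which can also be set from Python as long as that happens before JAX initializes its GPU backend. The sketch below is illustrative only: the helper configure_gpu_memory is hypothetical, while the variable names XLA_PYTHON_CLIENT_PREALLOCATE, XLA_PYTHON_CLIENT_MEM_FRACTION, and XLA_PYTHON_CLIENT_ALLOCATOR are the ones documented on the linked page.

```python
import os


def configure_gpu_memory(preallocate=True, mem_fraction=None, platform_allocator=False):
    """Set JAX GPU memory environment variables (hypothetical helper).

    Call this before importing jax; the settings are read when the GPU
    backend starts up.
    """
    if not preallocate:
        # Allocate GPU memory on demand instead of reserving a large pool up front.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if mem_fraction is not None:
        # Fraction of GPU memory to preallocate, e.g. 0.5 for half.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    if platform_allocator:
        # Allocate exactly what is needed and free it when done (slow; mainly for debugging).
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


# Try the first experiment: disable preallocation.
configure_gpu_memory(preallocate=False)

# import jax  # import JAX only after the variables above are set
```

The same effect can be had by exporting the variables in the shell before launching the script; setting them in Python only works if nothing has initialized the JAX backend yet.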

hawkinsp, Sep 23 '22 14:09

@hawkinsp Phenomenal! Disabling preallocation solves the issue!

I was aware that these flags were available if too much memory was being preallocated with Xorg (etc) active in the background, but I would have never made the connection with being unable to load cuDNN. What is the root cause here / how did you know to try that?

Ebanflo42, Sep 23 '22 14:09

The underlying problem is that cuDNN will fail to initialize if JAX grabs too much GPU memory.

There are two things we can do to improve matters here: a) improve the error message to suggest this as a possible cause, and b) change the JAX allocator design a bit.

(b) is longer term work but (a) we should be able to do sooner.

hawkinsp, Sep 23 '22 14:09