
Failed to determine best cudnn convolution algorithm/No GPU/TPU found

Open iwldzt3011 opened this issue 3 years ago • 8 comments

RTX 3080 / CUDA 11.1 / cuDNN 8.2.1 / Ubuntu 16.04

This problem occurs with jaxlib 0.1.72+cuda111. When I update to 0.1.74, it disappears. However, with 0.1.74, JAX cannot detect the GPU, while TensorFlow can.

Therefore, whether I use 0.1.72 or 0.1.74, there is always a problem.

```
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %custom-call.1 = (f32[1,112,112,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,229,229,3]{2,1,3,0} %pad, f32[7,7,3,64]{1,0,2,3} %copy.4), window={size=7x7 stride=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 224, 224, 3)\n padding=((2, 3), (2, 3))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(7, 7, 3, 64)\n window_strides=(2, 2)\n]" source_file="/media/node/Materials/anaconda3/envs/xmcgan/lib/python3.9/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{"algorithm":"0","tensor_ops_enabled":false,"conv_result_scale":1,"activation_mode":"0","side_input_scale":0}" failed. Falling back to default algorithm.

Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
```
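As the message suggests, the strict algorithm picker can be disabled via `XLA_FLAGS` when launching the program. A minimal sketch (the script name is a placeholder for whatever program triggers the error):

```shell
# Disable strict cuDNN conv algorithm picking so XLA falls back to a
# default algorithm instead of raising; "train.py" is hypothetical.
XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false python train.py
```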

iwldzt3011 avatar Dec 01 '21 05:12 iwldzt3011

What version of the jaxlib 0.1.74 wheel did you install, and how did you install it? Try removing jaxlib and reinstalling it following the instructions here: https://github.com/google/jax#pip-installation-gpu-cuda ?

hawkinsp avatar Dec 01 '21 14:12 hawkinsp

> What version of the jaxlib 0.1.74 wheel did you install, and how did you install it? Try removing jaxlib and reinstalling it following the instructions here: https://github.com/google/jax#pip-installation-gpu-cuda ?

I use a standalone build of jaxlib 0.1.74, i.e. plain pip install jaxlib.

The latest cuda111 build at https://storage.googleapis.com/jax-releases/jax_releases.html is still 0.1.72 (jaxlib 0.1.72+cuda111), so I can't get jaxlib 0.1.74+cuda111 from there.

However, that link does have jaxlib 0.1.74+cuda11, so I also tried that, but unfortunately this version gives the same error as jaxlib 0.1.72+cuda111.

iwldzt3011 avatar Dec 02 '21 05:12 iwldzt3011

Did you fix the error? 

ross-Hr avatar Jan 04 '22 03:01 ross-Hr

I faced this issue recently. I ran the "How to Think in JAX" documentation, and the Jupyter notebook reported this error when I tried to do a convolution. Moreover, the error sometimes disappears and I don't know why.

dljjqy avatar Feb 20 '22 09:02 dljjqy

I have this exact same issue when trying to run https://github.com/google/mipnerf. I get a "failed to determine best cudnn convolution algorithm" error when running jax.scipy.signal.convolve2d. I only get the error when running their code base, not when running the convolve operation by itself. It seems related to running vmap on convolve2d, and to the versions of CUDA and cuDNN being used.
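As background, vmapping convolve2d over a leading batch axis simply computes one 2D convolution per example; a pure-NumPy sketch of that computation (no GPU or XLA involved; the helper name is mine, not JAX's):

```python
import numpy as np

def convolve2d_valid(x, k):
    """Naive 2D convolution with 'valid' output size, mirroring what
    jax.scipy.signal.convolve2d(x, k, mode='valid') computes."""
    kf = k[::-1, ::-1]  # true convolution flips the kernel
    H = x.shape[0] - k.shape[0] + 1
    W = x.shape[1] - k.shape[1] + 1
    out = np.empty((H, W))
    for i in range(H):
        for j in range(W):
            out[i, j] = np.sum(x[i:i + k.shape[0], j:j + k.shape[1]] * kf)
    return out

# What vmap(convolve2d) over a batch axis computes: one conv per example.
batch = np.random.rand(4, 8, 8)
kernel = np.ones((3, 3))
outs = np.stack([convolve2d_valid(x, kernel) for x in batch])
print(outs.shape)  # (4, 6, 6)
```

On GPU, XLA lowers the vmapped version to a single batched cuDNN convolution, which is why the batched call can hit autotuning (and memory) limits that the unbatched call does not.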

CUDA 11.5 / cuDNN 8.3.2 / jax 0.3.2

half-potato avatar Mar 16 '22 20:03 half-potato

Turns out it was an OOM error hidden behind a bad error message. The solution is in #8506: use the environment variable XLA_PYTHON_CLIENT_MEM_FRACTION=0.87. There appears to be some issue with how jax.scipy.signal.convolve2d interacts with preallocated memory. I believe a better error message for this would be nice.
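One way to apply that workaround from Python, assuming you control the entry point, is to set the variable before jax is imported, since XLA reads it when the GPU backend is initialized (a sketch; the exact fraction that avoids the OOM depends on your workload):

```python
import os

# Cap JAX's GPU memory preallocation at 87% of device memory, leaving
# headroom for cuDNN autotuning workspaces. Must run BEFORE `import jax`.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.87"

# import jax  # import only after the environment is configured
print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])  # prints "0.87"
```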

half-potato avatar Mar 16 '22 20:03 half-potato

I have the same error on my Titan RTX, which is based on the Turing architecture. After some trial and error, I found the error may be related to the cuDNN version. If I export LD_LIBRARY_PATH with cuDNN 8.2.1, it works; cuDNN 8.2.4 does not.
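For reference, the switch described above amounts to prepending the desired cuDNN's library directory to the loader path before starting Python (the install path and script name below are hypothetical):

```shell
# Point the dynamic loader at a cuDNN 8.2.1 installation; the exact
# directory depends on where cuDNN was unpacked on your machine.
export LD_LIBRARY_PATH=/usr/local/cudnn-8.2.1/lib64:$LD_LIBRARY_PATH
python train.py  # placeholder for the failing program
```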

luweizheng avatar Jul 22 '22 07:07 luweizheng

Was this issue resolved? @iwldzt3011

sudhakarsingh27 avatar Aug 08 '22 20:08 sudhakarsingh27

Closing since no activity / no additional info provided.

sudhakarsingh27 avatar Aug 24 '22 19:08 sudhakarsingh27

(I should add: if someone can provide instructions to reproduce the problem, e.g., on a cloud GPU VM or similar, we would love to look into it further!)

hawkinsp avatar Aug 24 '22 19:08 hawkinsp

Hello all.

I don't have a GPU VM, but I can confirm I have the same problem with an EVGA 3070 Ti XC3. What may help pin down the problem is that I installed the conda recipe using:

```
conda install jax cuda-nvcc -c conda-forge -c nvidia
```

The nvcc version info in the conda environment reads as follows:

```
$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0
```

I found the cuDNN version in the include folder of the virtual env:

```
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 4
#define CUDNN_PATCHLEVEL 1
```
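For anyone checking their own environment: in cuDNN 8.x these version macros live in `cudnn_version.h`, so a quick way to read them from an activated conda env (where `$CONDA_PREFIX` points at the env root) is:

```shell
# Print the cuDNN version macros shipped with the environment.
grep -A2 'define CUDNN_MAJOR' "$CONDA_PREFIX/include/cudnn_version.h" | head -n 3
```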

Any chance this helps reproduce the problem? If you have a temporary workaround, I'd love to try it.

hcwinsemius avatar Aug 27 '22 10:08 hcwinsemius

I did the following and it works

```
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"
```

amughrabi avatar Nov 25 '22 22:11 amughrabi

> I did the following and it works
>
> ```
> export XLA_PYTHON_CLIENT_PREALLOCATE=false
> export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"
> ```

Thanks, this works!

BeaverInGreenland avatar Jan 10 '24 11:01 BeaverInGreenland