Warn if a GPU is detected but it is too old to be supported by XLA

Open clemisch opened this issue 5 years ago • 3 comments

Dear jax team,

I'm struggling to install jax with GPU support. I'm running Ubuntu 18.04 with CUDA 10.0 and cuDNN 7.6.1. The GPU is a Quadro K4000 with the 410.48 driver (manually installed, no conda).

I tried installing jax from pip and from the repo. For jaxlib, I tried the pip wheels (according to the guide in the readme.md) and compiling from source with

python3 build/build.py --enable_cuda --cuda_path /usr/local/cuda-10.0/ --cudnn_path /usr/local/cuda-10.0/ --enable_march_native

all with no success: jax keeps falling back to CPU. I tried all combinations both with --user installs and global sudo installs. Also with reboots in between.

I made sure I was using the correct jax / jaxlib install every time by checking jax.__file__ and jaxlib.__file__. I'm out of ideas now. Is there a known problem with Quadro cards? Could you point me to somewhere I have not yet looked for errors? Thank you very much!
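Besides checking the module paths, the selected backend can be queried directly. This is a minimal sketch that assumes a recent JAX release in which `jax.default_backend()` and `jax.devices()` are available (older releases, including those current when this issue was opened, may not have them). `describe_backend` is a hypothetical helper name, not part of JAX.

```python
def describe_backend():
    """Return a one-line summary of the backend JAX selected.

    Hypothetical helper for illustration; assumes a recent JAX release.
    """
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    backend = jax.default_backend()  # platform name chosen by JAX, e.g. "cpu" or "gpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


print(describe_backend())
```

If the summary reports a CPU backend on a machine with an NVIDIA GPU, JAX has fallen back to CPU, which matches the symptom described above.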

clemisch avatar Jul 08 '19 14:07 clemisch

I've noticed the K4000 GPU is pretty old and only supports CUDA compute capability 3.0. Could this be the problem? I saw no errors when compiling jaxlib. If jax needs CUDA compute capability > 3.0 (as TensorFlow apparently does), there should be a corresponding message when falling back to CPU.

clemisch avatar Jul 08 '19 17:07 clemisch

Yes, I think your GPU is too old. JAX uses XLA for its GPU support, which has the same GPU requirements as TensorFlow: https://www.tensorflow.org/install/gpu#hardware_requirements

I agree that a helpful warning might be a good thing.
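As a rough illustration of what such a warning could look like, here is a minimal sketch in plain Python. It is not how JAX or XLA implements the check. The function names are hypothetical, and the minimum of 3.5 is an assumed placeholder; the real threshold should be taken from the hardware requirements linked above. The compute capability string is passed in by the caller, since how it is obtained depends on the setup.

```python
import warnings

# Assumed placeholder threshold for illustration only; the real minimum
# depends on the XLA / jaxlib build in use.
MIN_COMPUTE_CAPABILITY = (3, 5)


def parse_compute_capability(text):
    """Turn a string such as "3.0" into a comparable tuple such as (3, 0)."""
    major, minor = text.strip().split(".")
    return int(major), int(minor)


def check_gpu_supported(name, capability, minimum=MIN_COMPUTE_CAPABILITY):
    """Warn and return False if the GPU is older than the assumed minimum."""
    found = parse_compute_capability(capability)
    if found < minimum:
        warnings.warn(
            f"GPU '{name}' has CUDA compute capability {found[0]}.{found[1]}, "
            f"below the assumed minimum {minimum[0]}.{minimum[1]}; "
            "falling back to CPU."
        )
        return False
    return True


# Example with the card from this issue (compute capability 3.0).
check_gpu_supported("Quadro K4000", "3.0")
```

Comparing tuples rather than floats avoids ordering mistakes with two-digit minor versions.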

hawkinsp avatar Jul 08 '19 18:07 hawkinsp

Thanks for the info @hawkinsp. Not great for me right now, but it might save others some time in the future :+1:

clemisch avatar Jul 08 '19 18:07 clemisch