Warn if a GPU is detected but it is too old to be supported by XLA
Dear jax team,
I'm struggling to install jax with GPU support. I'm running Ubuntu 18.04 with CUDA 10.0 and cuDNN 7.6.1. The GPU is a Quadro K4000 with the 410.48 drivers (installed manually, no conda).
I tried installing jax both from pip and from the repo. For jaxlib, I tried the pip wheels (following the guide in the README.md) and also compiling from source with
python3 build/build.py --enable_cuda --cuda_path /usr/local/cuda-10.0/ --cudnn_path /usr/local/cuda-10.0/ --enable_march_native
all with no success: jax keeps falling back to CPU. I tried every combination, with both --user installs and global sudo installs, and with reboots in between.
I made sure I was using the correct jax/jaxlib install every time by checking jax.__file__ and jaxlib.__file__. I'm out of ideas now. Is there a known problem with Quadro cards? Could you point me toward a place I haven't yet looked for errors? Thank you very much!
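As a hedged diagnostic sketch (not part of the original report): besides checking jax.__file__ and jaxlib.__file__, you can ask JAX which backend and devices it actually selected. jax.devices() is part of JAX's public API. jax.default_backend() exists in recent releases, while older releases exposed the platform through jax.lib.xla_bridge.get_backend().platform, so adjust for your version. The function name describe_jax_install is hypothetical.

```python
import importlib


def describe_jax_install():
    """Report where jax/jaxlib are loaded from and which backend JAX picked.

    Hypothetical helper. All fields stay None if jax/jaxlib cannot be imported.
    """
    info = {"jax_file": None, "jaxlib_file": None, "backend": None, "devices": None}
    try:
        jax = importlib.import_module("jax")
        jaxlib = importlib.import_module("jaxlib")
    except ImportError:
        return info
    info["jax_file"] = jax.__file__
    info["jaxlib_file"] = jaxlib.__file__
    # Assumes a recent JAX; older versions used
    # jax.lib.xla_bridge.get_backend().platform instead of default_backend().
    info["backend"] = jax.default_backend()  # "cpu" here means no usable GPU was found
    info["devices"] = [str(d) for d in jax.devices()]
    return info


if __name__ == "__main__":
    for key, value in describe_jax_install().items():
        print(f"{key}: {value}")
```

If backend reports "cpu" even though the jaxlib path points at a CUDA-enabled build, the GPU itself was not accepted. That is consistent with the hardware limitation discussed below.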
I've noticed that the K4000 GPU is pretty old and only supports CUDA compute capability 3.0. Could this be the problem? I saw no errors while compiling jaxlib. If jax needs a CUDA compute capability above 3.0 (as TensorFlow apparently does), there should be a corresponding message when it falls back to CPU.
Yes, I think your GPU is too old. JAX uses XLA for its GPU support, which has the same GPU requirements as TensorFlow: https://www.tensorflow.org/install/gpu#hardware_requirements
I agree that a helpful warning might be a good thing.
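Here is a minimal sketch of the kind of warning proposed here, written in plain Python rather than as a patch to JAX or XLA. Everything in it is hypothetical. MIN_COMPUTE_CAPABILITY, parse_compute_capability and check_gpu_supported are illustrative names. The (3, 5) threshold is an assumption loosely based on the TensorFlow hardware requirements linked above, and the real threshold depends on the XLA build. How the capability is read from the driver is left out.

```python
import warnings
from typing import Tuple

# Assumed minimum; the real value depends on what the installed XLA build supports.
MIN_COMPUTE_CAPABILITY: Tuple[int, int] = (3, 5)


def parse_compute_capability(text: str) -> Tuple[int, int]:
    """Parse a string such as '3.0' or '7.5' into a (major, minor) tuple."""
    major, minor = text.strip().split(".")
    return int(major), int(minor)


def check_gpu_supported(
    name: str,
    capability: str,
    minimum: Tuple[int, int] = MIN_COMPUTE_CAPABILITY,
) -> bool:
    """Return True if the GPU meets `minimum`; otherwise warn and return False."""
    cc = parse_compute_capability(capability)
    if cc < minimum:  # tuple comparison: major first, then minor
        warnings.warn(
            f"GPU {name!r} has CUDA compute capability {cc[0]}.{cc[1]}, "
            f"below the supported minimum {minimum[0]}.{minimum[1]}; "
            "falling back to CPU.",
            RuntimeWarning,
        )
        return False
    return True
```

With these assumed values, check_gpu_supported("Quadro K4000", "3.0") emits the warning and returns False. Comparing (major, minor) tuples rather than raw strings avoids mistakes such as "10.0" sorting before "3.5".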
Thanks for the info, @hawkinsp. Not great for me right now, but it might save others some time in the future :+1: