Installing JAX on Arch
Installing JAX on Arch has been surprisingly difficult. I've been trying to (re)install it for several hours now, after accidentally updating CUDA. While I got it working last time (with CUDA 11.0), what I did then doesn't work now.
Steps taken to install
I'm working with the AUR repository using yay:
yay cuda
yay cudnn
It reports the packages are successfully installed, with these versions:
cuda-11.2.1-2
cudnn-8.1.0.77-1
pip is already fully upgraded (21.0.11), so now for JAX:
$ sudo pip install --upgrade --force jax jaxlib==0.1.62+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Successfully installed jax-0.2.10 jaxlib-0.1.62+cuda112
Because JAX expects CUDA at /usr/local/cuda-XX.X, but Arch installs CUDA at /opt/cuda, I create a symbolic link:
sudo ln -s /opt/cuda /usr/local/cuda-11.2
Checking installation
Just as a sanity check, I see if JAX can access devices in a Python shell, which it can:
$ python
>>> import jax
>>> jax.devices()
[GpuDevice(id=0)]
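For a slightly stronger sanity check, one can also confirm that the default device really is the GPU and run a small op on it (a minimal sketch):

import jax
import jax.numpy as jnp

# The first listed device should be the GPU, not a CPU fallback.
dev = jax.devices()[0]
print(dev, dev.platform)                 # e.g. GpuDevice(id=0) gpu

# Run a small op and wait for the result, so any driver/runtime problem
# surfaces here rather than later.
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x).block_until_ready()
print(y.shape)                           # (1000, 1000)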
I then try to run the following example, from Convolutions in JAX:
from jax import numpy as jnp, random, lax

key = random.PRNGKey(1701)

kernel = jnp.zeros((3, 3, 3, 3))
kernel += jnp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,1]])[:, :, jnp.newaxis, jnp.newaxis]

img = jnp.zeros((1, 200, 198, 3))
for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img = img.at[0, x:x+10, y:y+10, k].set(1)

out = lax.conv(jnp.transpose(img, [0,3,1,2]),
               jnp.transpose(kernel, [3,2,0,1]),
               (1, 1),
               'SAME')
I get the following errors.
2021-03-12 04:30:44.451633: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,3,200,198]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,3,200,198]{3,2,1,0} %parameter.1, f32[3,3,3,3]{3,2,1,0} %parameter.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 3, 200, 198)\n padding=((1, 1), (1, 1))\n precision=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 3, 3)\n window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal.
2021-03-12 04:30:44.548471: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1881] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3294): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
File "/home/kuhlig/Documents/Programming/convolutional-deconvolution/test.py", line 17, in <module>
out = lax.conv(jnp.transpose(img, [0,3,1,2]),
File "/usr/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1582, in conv
return conv_general_dilated(lhs, rhs, window_strides, padding,
File "/usr/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 600, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/usr/lib/python3.9/site-packages/jax/core.py", line 284, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/usr/lib/python3.9/site-packages/jax/core.py", line 622, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/lib/python3.9/site-packages/jax/interpreters/xla.py", line 242, in apply_primitive
return compiled_fun(*args)
File "/usr/lib/python3.9/site-packages/jax/interpreters/xla.py", line 360, in _execute_compiled_primitive
out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3294): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
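For reference, one way to check whether the failure is specific to the GPU path is to force the same convolution onto the CPU backend (a rough sketch using jax.jit's backend argument):

from functools import partial
from jax import jit, numpy as jnp, lax

# Same shapes as in the example above.
img = jnp.zeros((1, 200, 198, 3))
kernel = jnp.zeros((3, 3, 3, 3))

# Compile and run on the CPU backend, bypassing cuDNN entirely; if this
# succeeds, the failure is specific to the GPU path.
@partial(jit, backend="cpu")
def conv_cpu(img, kernel):
    return lax.conv(jnp.transpose(img, [0, 3, 1, 2]),
                    jnp.transpose(kernel, [3, 2, 0, 1]),
                    (1, 1),
                    'SAME')

print(conv_cpu(img, kernel).shape)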
I've read through everything I can, and the only suggestion I can find is that the jaxlib, cuda, or cudnn versions must be mismatched. Unfortunately, they don't seem to be:
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Jan_28_19:32:09_PST_2021
Cuda compilation tools, release 11.2, V11.2.142
Build cuda_11.2.r11.2/compiler.29558016_0
$ whereis cudnn_version
cudnn_version: /usr/include/cudnn_version.h
$ cat /usr/include/cudnn_version.h
...
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 1
#define CUDNN_PATCHLEVEL 0
...
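To rule out the dynamic loader picking up a different libcudnn than the headers suggest, the runtime cuDNN version can also be queried directly; a rough ctypes sketch (cudnnGetVersion is part of the cuDNN library API):

import ctypes

# Load whichever libcudnn.so.8 the dynamic loader resolves and ask it for its
# version; cudnnGetVersion() returns e.g. 8100 for cuDNN 8.1.0.
lib = ctypes.CDLL("libcudnn.so.8")
lib.cudnnGetVersion.restype = ctypes.c_size_t
print(lib.cudnnGetVersion())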
What I did last time (or thought I did – something got it working, and I might be attributing it to the wrong thing) was to create symbolic links to the cudnn files in /usr/local/cuda-11.2/include and /usr/local/cuda-11.2/lib64, as follows:
sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.2/include
sudo ln -s /usr/lib64/libcudnn*.so /usr/lib64/libcudnn_static.a /usr/local/cuda-11.2/lib64
This, unfortunately, does not seem to change anything, so I might just be barking up the wrong tree. Any help?
I also just compiled jaxlib from source (pointing it at my CUDA and CUDNN installations) successfully:
python build/build.py --enable_cuda --cuda_path /usr/local/cuda-11.2 --cudnn_path /usr/ --cuda_version 11.2 --cudnn_version 8
I then made a virtual environment with venv and installed jaxlib and JAX using the produced wheel, but the same error remains.
Running pytest -n auto tests for about 20 minutes resulted in a lot of failed tests, and a few errors.
This might be an issue with the jaxlib+cuda112 wheel (I admit I only tested the cuda 11.0 version!). I can try it out later, but to any passersby on cuda 11.2: have you successfully used the jaxlib+cuda112 wheel?
Installing JAX on Arch has been surprisingly difficult
I thought difficult installations were the whole reason people used Arch! One time in grad school, my X11 setup on Arch was broken for months due to a pacman update, so I just learned to work without a graphical interface (until I finally gave up and re-imaged the machine). At least I had the bleeding-edge version of wget, though.
Maybe the lesson here is that if JAX installation is hard for an Arch user, it must be really hard... 😄
I just tried your example with CUDA 11.2 and jaxlib 0.1.62 and it works for me. What kind of GPU do you have? I'll also ask the XLA:GPU team if they have any thoughts on what might cause that error.
I have two graphics cards (one integrated and one discrete; currently the integrated one shouldn't be used in any of the following, as verified by running nvidia-smi while using JAX). The discrete one is an NVIDIA GeForce MX150 (GP108M), using the video-hybrid-intel-nvidia-prime driver.
Here's the output of nvidia-smi while running the example I provided, for (slightly) more details:
$ nvidia-smi
Fri Mar 12 18:23:40 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce MX150       Off  | 00000000:02:00.0 Off |                  N/A |
| N/A   45C    P3    N/A /  N/A |   1999MiB /  2002MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A       486      G   /usr/lib/Xorg                       4MiB |
|    0   N/A  N/A     94675      C   python                           1993MiB |
+-----------------------------------------------------------------------------+
Also, @matjj, if you're not losing the graphical interface at least once every few months, you're not having fun :)
I've tried looking for any other files that might be left over from past installations or other libraries like PyTorch. I removed all of those (even though I know PyTorch, for example, ships with its own CUDA binaries and logically couldn't interfere with JAX) to get as clean a slate as I could, then reinstalled.
I'm still seeing the same error, which is really frustrating, given that JAX worked a couple weeks ago with CUDA 11.0. Everything is pointing towards this not being a problem with JAX, but something else in my environment. Does anyone have ideas of what else I could check? What could be messing with JAX?
Do you have a symlink /usr/local/cuda (no version number)? If not, try adding one?
@hawkinsp I get the same error, even after adding a symlink from /usr/local/cuda to /opt/cuda.
Not sure if this is still an issue, but I just managed to upgrade from cuda 11.0 to cuda 11.3 and add the new cuda 11.1 jax wheels, without much effort, on Arch. Maybe you can give it another try?
These are the steps I followed, in case they help somebody.
- Remove old cuda and cudnn
sudo pacman -Rns cudnn8-cuda11.0 cuda-11.0
- Install latest cuda and cudnn and update package list
sudo pacman -Su cuda cudnn
To test that all went fine:
> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Mar_21_19:15:46_PDT_2021
Cuda compilation tools, release 11.3, V11.3.58
Build cuda_11.3.r11.3/compiler.29745058_0
- Create symlink
sudo ln -s /opt/cuda /usr/local/cuda-11.3
- Update jax
pip install --upgrade jax jaxlib==0.1.66+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Testing it out:
> cat test.py
from jax import numpy as jnp, random, lax

key = random.PRNGKey(1701)

kernel = jnp.zeros((3, 3, 3, 3))
kernel += jnp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,1]])[:, :, jnp.newaxis, jnp.newaxis]

img = jnp.zeros((1, 200, 198, 3))
for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img = img.at[0, x:x+10, y:y+10, k].set(1)

out = lax.conv(jnp.transpose(img, [0,3,1,2]),
               jnp.transpose(kernel, [3,2,0,1]),
               (1, 1),
               'SAME')
> python test.py
Hi, sorry for the delay. Is this still an issue for anyone?
@skye I'm still experiencing the same exact errors for the provided sample (and anything involving convolution), but literally everything else in JAX works. I'm currently using Flax for another project without issue – I just can't use convolution. My installation is:
cudnn8-cuda11.0
cuda11.0
jaxlib==0.1.67+cuda110
I'll admit I haven't tried updating to the newest versions of cuda and cudnn in a while, so I'll try one more time.
@Numeri Where is libcudnn.so.8 located on your system?
@hawkinsp /opt/cuda/targets/x86_64-linux/lib/libcudnn.so.8
@Numeri Can you try symlinking /opt/cuda to /usr/local/cuda?
Another option might be to add /opt/cuda/targets/x86_64-linux/lib to your LD_LIBRARY_PATH.
@hawkinsp It is symlinked to both /usr/local/cuda and /usr/local/cuda-11.0, but I can try adding it to my LD_LIBRARY_PATH.
I also just tried using the exact versions that @astanziola used, and have the same issue.
Edit: I have the same error after adding export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/cuda/targets/x86_64-linux/lib to my .bashrc and source-ing it.
Can you share the complete log with TF_CPP_MIN_LOG_LEVEL=0 when you run a convolution? The log at the top of the bug is missing a few things; I'm hoping maybe something interesting is in the logs.
I have another thing to try. Can you try with XLA_PYTHON_CLIENT_MEM_FRACTION=.5 or XLA_PYTHON_CLIENT_ALLOCATOR=platform?
So I've tried both of those before – I can't remember the exact output, but I know they didn't work or only delayed this issue temporarily. I can try those and the logging flag once I'm back at my laptop.
> Can you share the complete log with TF_CPP_MIN_LOG_LEVEL=0 when you run a convolution? The log at the top of the bug is missing a few things; I'm hoping maybe something interesting is in the logs.
numeri ~ $ TF_CPP_MIN_LOG_LEVEL=0 python3 /tmp/test.py
2021-06-09 01:19:41.611304: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x55d593bf72d0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2021-06-09 01:19:41.611338: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): Interpreter, <undefined>
2021-06-09 01:19:41.631288: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x55d593beb7d0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-06-09 01:19:41.631359: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): Host, Default Version
2021-06-09 01:19:41.822625: I external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-06-09 01:19:41.823158: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x55d593c13a80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2021-06-09 01:19:41.823191: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): NVIDIA GeForce MX150, Compute Capability 6.1
2021-06-09 01:19:41.823585: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:298] Using BFC allocator.
2021-06-09 01:19:41.823648: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:257] XLA backend allocating 1854288691 bytes on device 0 for BFCAllocator.
2021-06-09 01:19:41.824216: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
2021-06-09 01:19:44.491876: I external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
2021-06-09 01:19:44.837236: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,3,200,198]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,3,200,198]{3,2,1,0} %parameter.1, f32[3,3,3,3]{3,2,1,0} %parameter.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 3, 200, 198)\n padding=((1, 1), (1, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 3, 3)\n window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal.
2021-06-09 01:19:44.840458: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1981] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3910): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
File "/tmp/test.py", line 17, in <module>
out = lax.conv(jnp.transpose(img, [0,3,1,2]),
File "/home/numeri/.local/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1690, in conv
return conv_general_dilated(lhs, rhs, window_strides, padding,
File "/home/numeri/.local/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 625, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/home/numeri/.local/lib/python3.9/site-packages/jax/core.py", line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/numeri/.local/lib/python3.9/site-packages/jax/core.py", line 606, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/numeri/.local/lib/python3.9/site-packages/jax/interpreters/xla.py", line 232, in apply_primitive
return compiled_fun(*args)
File "/home/numeri/.local/lib/python3.9/site-packages/jax/interpreters/xla.py", line 350, in _execute_compiled_primitive
out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3910): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
I've also tried XLA_PYTHON_CLIENT_MEM_FRACTION=.5 and XLA_PYTHON_CLIENT_ALLOCATOR=platform, and they make this example work just fine, but I'm still seeing errors when using convolutions in a model – it's trying to allocate 1.25 GiB on my 2 GiB GPU and failing, when the biggest convolutional layer I'm using is 32x32 with 256 features – a size that runs fine with PyTorch.
@Numeri Ok, I think that resolves the issue. This has nothing to do with Arch: everything is working fine on Arch as you've set it up. The issue is that your GPU doesn't have very much memory. Both cuDNN and JAX need some memory to work, and by default JAX allocates too much. See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for more details.
We might be able to tweak the defaults to make things work a little better on low-memory configurations, but it's a niche use case (2GB is pretty small for a current GPU).
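As a sketch of the options described on that page, the reservation can be disabled or shrunk via environment variables set before JAX allocates its GPU memory pool:

import os

# Either skip XLA's up-front GPU memory reservation entirely...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or reserve a smaller fraction of GPU memory than the default.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"

import jax  # set the variables before importing/using jax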
@hawkinsp That's fair; I just thought it was a bug because it was only happening to me with JAX and not with PyTorch (I can use much, much larger models with PyTorch, but the difference is only there for convolution layers). I assumed that it was my CUDA/cuDNN install, since PyTorch comes with a prepackaged install, as far as I understand. Personally, this still feels like a bug, but you're definitely right that it has nothing to do with Arch.
For those who get here via a Google search like I did and who, while skimming, missed the solution that hawkinsp gave: the fix is to prevent JAX from preallocating too much memory by setting the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to something lower than 0.9:
$ export XLA_PYTHON_CLIENT_MEM_FRACTION=.7
Or:
>>> import os
>>> os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".7"
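As far as I understand, this has to happen before JAX initializes its GPU backend, so in a script put it at the very top, before the first JAX import or computation.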
My GPU has 8GB of memory (about 6.5GB after the OS takes its share), which is still small by normal machine-learning standards, but this still seems more like a bug than a user error. I wonder if the situation here could be improved, @hawkinsp? Even just a warning for GPUs that have <10GB of memory, informing the user that they may need to set XLA_PYTHON_CLIENT_MEM_FRACTION to a lower value to prevent OOM errors?
@Numeri is this resolved?
@sudhakarsingh27 I think it's unfortunate that JAX uses so much more memory than PyTorch does for the same exact layers, but it doesn't seem like there is an easy fix, so I think we can close this, yes.