jax.random.split uses extra memory before preallocated memory is used up
I monitored GPU memory usage via nvidia-smi.
I find that when I run the command
jax.random.split()
JAX always allocates additional memory, even though the preallocated memory has not been used at all. This keeps raising OOM errors, since JAX already preallocates 90% of GPU memory by default.
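A minimal way to see this (a sketch, assuming a CUDA-enabled jax/jaxlib install) is to watch nvidia-smi in a second terminal while running:

```python
import jax

# Creating the first key already triggers JAX's default preallocation.
key = jax.random.PRNGKey(0)

# Splitting the key is where the extra, non-preallocated memory shows up.
subkeys = jax.random.split(key)  # two new keys, shape (2, 2) uint32
```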
I have the same problem when running jax.random.split().
jax 0.3.13, jaxlib 0.3.10+cuda11.cudnn82
I have the same problem when running jax.random.split(); it crashes with:
RuntimeError: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: out of memory
Can anyone help?
My environment:
1. OS
   cat /proc/version → Linux version 4.15.0-142-generic (buildd@lgw01-amd64-039) (gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.12)) #146~16.04.1-Ubuntu SMP Tue Apr 13 09:27:15 UTC 2021
   uname -r → 4.15.0-142-generic
2. gcc
   gcc --version → gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
3. CUDA && cuDNN
   I have multiple CUDA versions installed (cuda-9.0, cuda-10.2, cuda-11.3); the default /usr/local/cuda is symlinked to cuda-10.2 (stat /usr/local/cuda → File: '/usr/local/cuda' -> '/usr/local/cuda-10.2/'). I use a conda virtual environment named 'plen' in which CUDA is 11.3:
   conda env config vars set PATH=/usr/local/cuda-11.3/bin:$PATH -n plen
   conda env config vars set LD_LIBRARY_PATH=/usr/local/cuda-11.3/lib64:$LD_LIBRARY_PATH -n plen
   conda env config vars set CUDA_HOME=/usr/local/cuda-11.3 -n plen
   conda env config vars set CUDA_PATH=/usr/local/cuda-11.3 -n plen
   conda activate plen
   nvcc -V → Cuda compilation tools, release 11.3, V11.3.109 Build cuda_11.3.r11.3/compiler.29920130_0
   sudo find /usr -name cudnn_version.h → /usr/include/cudnn_version.h
4. jax && jaxlib
   jax 0.2.26, jaxlib 0.1.75+cuda11.cudnn82
5. tensorflow
   tensorflow 2.3.1
Can you try lowering the fraction of memory that JAX allocates to something like 0.8? (See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html)
Does it help?
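Concretely, the knobs described on that page are environment variables that must be set before JAX starts (the script name below is just a placeholder):

```shell
# Preallocate 80% of GPU memory instead of the default 90%:
XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python your_script.py

# Or disable preallocation entirely so JAX grows memory on demand:
XLA_PYTHON_CLIENT_PREALLOCATE=false python your_script.py
```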
I have been experiencing exactly the same issue on Ubuntu 22.04, graphics card driver CUDA 11.7, nvcc CUDA version 11.2, cuDNN 8.2.
Doing nothing but assigning a random key in JAX costs 3900 MB of GPU memory, and every time you split that key it multiplies again, roughly 3900^(n_splits) MB in total: an exponential memory blow-up!
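For reference, the splitting pattern in question (a sketch, assuming a standard jax install) can produce all subkeys in a single call rather than through repeated pairwise splits:

```python
import jax

key = jax.random.PRNGKey(0)

# One call producing 8 subkeys; each subkey is a (2,) uint32 array,
# so the key data itself is tiny compared to the reported 3900 MB.
subkeys = jax.random.split(key, num=8)
```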
It's a similar situation for me as well. I tried tweaking the memory allocation settings from https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html but I am still getting OOM errors.
I have the same issue. No matter what I set the memory fraction to, this will not run.