
jax.random.split uses extra memory before preallocated memory is used up

Open Xin-yang-Liu opened this issue 2 years ago • 5 comments

I monitored the GPU memory usage via nvidia-smi. I find that when I run the command

jax.random.split()

JAX always allocates additional memory, even though the preallocated memory is not used up at all. This keeps causing OOM errors, since JAX has already preallocated 90% of GPU memory by default.
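A minimal sketch of the reproduction, assuming a CUDA-enabled jaxlib build, with GPU memory watched via nvidia-smi in a separate terminal:

import jax

key = jax.random.PRNGKey(0)          # first device allocation; JAX preallocates ~90% of GPU memory by default
key, subkey = jax.random.split(key)  # the extra memory usage shows up on this call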

Xin-yang-Liu avatar May 27 '22 21:05 Xin-yang-Liu

I have the same problem when running jax.random.split().

jax 0.3.13 jaxlib 0.3.10+cuda11.cudnn82

yueyang130 avatar Jul 16 '22 04:07 yueyang130

I have the same problem when running jax.random.split(); it crashes, and the crash detail is:

RuntimeError: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: out of memory

WHO CAN HELP ME!!!


My environment:

1. OS
command: cat /proc/version
result: Linux version 4.15.0-142-generic (buildd@lgw01-amd64-039) (gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.12)) #146~16.04.1-Ubuntu SMP Tue Apr 13 09:27:15 UTC 2021
command: uname -r
result: 4.15.0-142-generic

2. gcc
command: gcc --version
result: gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609

3. CUDA && cuDNN
I have multiple CUDA versions on my OS (cuda-9.0, cuda-10.2, cuda-11.3), and the default /usr/local/cuda is a symlink to cuda-10.2:
command: stat /usr/local/cuda
result: File: '/usr/local/cuda' -> '/usr/local/cuda-10.2/'
I use a conda virtual environment named 'plen'; inside it, the CUDA version is 11.3:
conda env config vars set PATH=/usr/local/cuda-11.3/bin:$PATH -n plen
conda env config vars set LD_LIBRARY_PATH=/usr/local/cuda-11.3/lib64:$LD_LIBRARY_PATH -n plen
conda env config vars set CUDA_HOME=/usr/local/cuda-11.3 -n plen
conda env config vars set CUDA_PATH=/usr/local/cuda-11.3 -n plen
conda activate plen
command: nvcc -V
result: Cuda compilation tools, release 11.3, V11.3.109 Build cuda_11.3.r11.3/compiler.29920130_0
command: sudo find /usr -name cudnn_version.h
result: /usr/include/cudnn_version.h

4. jax && jaxlib
JAX: 0.2.26
JAXLIB: 0.1.75+cuda11.cudnn82

5. tensorflow
tensorflow: 2.3.1

pylon2008 avatar Jul 26 '22 11:07 pylon2008

Can you try lowering the fraction of memory that JAX allocates to something like 0.8? (See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html)

Does it help?
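For reference, a sketch of that suggestion, assuming the variable is set before JAX initializes its GPU backend (as described on the linked page):

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"  # preallocate 80% of GPU memory instead of the default 90%

import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

Setting the same variable in the shell before launching the script is equivalent.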

hawkinsp avatar Jul 27 '22 20:07 hawkinsp

I have been experiencing exactly the same issue on Ubuntu 22.04 (graphics driver CUDA 11.7, nvcc CUDA 11.2, cuDNN 8.2).

Doing nothing but assigning a random key in JAX costs 3900 MB of GPU memory. And any time you split that key, it costs 3900^(n_splits) MB of GPU memory: an exponential memory blow-up!
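A rough sketch for checking this, assuming nvidia-smi is on the PATH; it logs the GPU memory reported for device 0 after each split so the growth can be measured directly:

import subprocess

import jax

def gpu_mem_mb():
    # Used memory on GPU 0 as reported by nvidia-smi, in MB.
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]
    )
    return int(out.splitlines()[0])

key = jax.random.PRNGKey(0)
print("after PRNGKey:", gpu_mem_mb(), "MB")
for i in range(4):
    key, _ = jax.random.split(key)
    print("after split", i + 1, ":", gpu_mem_mb(), "MB")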


onurdanaci avatar Aug 31 '22 16:08 onurdanaci

It's a similar situation for me as well. I tried tweaking the memory allocation settings from https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html but am still getting an OOM error.

jaymehta1212 avatar Sep 06 '22 19:09 jaymehta1212

I have the same issue. No matter what I set the memory fraction to, it will not run.
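For completeness, the other knobs from the same docs page, sketched here under the assumption that they are set before JAX initializes: disabling preallocation entirely, or switching to the platform allocator (slower, but it allocates and frees exactly what is needed):

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"    # allocate on demand instead of reserving 90% up front
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" # alternative: allocate/free exactly what is needed

import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)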

SamTov avatar Dec 02 '22 15:12 SamTov