
Jax leaks memory with random keys

Open • smorad opened this issue 11 months ago • 12 comments

Description

I'm able to leak memory on CPU by running the following script:

from jax import random
import psutil

key = random.PRNGKey(0)
key, netkey = random.split(key)

iters = 500_000
key, *data_keys = random.split(key, iters + 1)
for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")

When run, it prints:

910.5 MB
910.7 MB
910.8 MB
911.0 MB
911.3 MB
911.5 MB
911.7 MB
911.9 MB
912.1 MB
912.3 MB
912.5 MB
912.7 MB
912.9 MB
913.2 MB
...

This also uses way more memory than I was expecting. Why does the first iteration already require nearly 1 GB? 500,000 random keys at 8 bytes each should take up only about 4 MB, right?
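A quick check of that arithmetic, assuming the default (legacy, raw `uint32`) key representation: each key from `random.PRNGKey` is a `uint32[2]` array, so 8 bytes, and a single batched array of 500,001 keys holds about 4 MB of key data. The raw key bits alone therefore do not explain the ~900 MB; the replies below discuss other contributors, such as per-array overhead.

```python
from jax import random

# Assumes the default raw-key representation: each key is a uint32[2] array.
key = random.PRNGKey(0)
print(key.shape, key.dtype, key.nbytes)  # (2,) uint32 8

# One batched array holding all 500,001 keys, with no unpacking.
keys = random.split(key, 500_001)
print(keys.shape, keys.nbytes)  # (500001, 2) 4000008
```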

What jax/jaxlib version are you using?

jax-0.4.14, jaxlib-0.4.14

Which accelerator(s) are you using?

CPU

Additional system info

python-3.11.4, MacOS

NVIDIA GPU info

No response

smorad, Sep 05 '23

This looks like a garbage collection issue.

In each iteration, once inputs goes out of scope, it doesn't immediately get deleted. Rather, its CPython reference count goes to zero, and sometime later the garbage collector will delete the variable.

Once the garbage collector deletes the Python object, it triggers a call to XLA to clear the allocated buffer from the device. This delete call is also asynchronous, and will happen some time later in the program.

If you want the arrays/buffers from previous iterations to be freed more quickly during your program's execution, one thing you can do is to explicitly delete them or run the garbage collector, though this will slow down your code.

For example:

import gc
import psutil
from jax import random

key = random.PRNGKey(0)
iters = 500_000
key, *data_keys = random.split(key, iters + 1)

for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")
        gc.collect()  # Python garbage collection: this is slow, so avoid doing it every iteration.
    inputs.delete()  # JAX/XLA buffer deletion. This is fast
                     # (it's an asynchronous call) so we can do it every iteration.
    del inputs

You should see the memory use greatly reduced.
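To see what `inputs.delete()` does on its own, `jax.Array` also provides `is_deleted()`. A minimal sketch: after `delete()` the array's underlying buffer is released and the array can no longer be read, even though the Python name still exists until `del`.

```python
import gc

from jax import random

inputs = random.normal(random.PRNGKey(0), (10,))
before = inputs.is_deleted()  # False: the buffer is still live

inputs.delete()               # ask XLA to release the underlying buffer
after = inputs.is_deleted()   # True: the array's data is gone

del inputs                    # drop the Python name as well
gc.collect()                  # optional and comparatively slow
print(before, after)  # False True
```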

jakevdp, Sep 05 '23

Interestingly, this doesn't grow:

from jax import random
import psutil

key = random.PRNGKey(0)
key, netkey = random.split(key)

iters = 500_000
data_keys = random.split(key, iters)  # notice no tuple being formed here
for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")

But this does (one line added):

from jax import random
import psutil

key = random.PRNGKey(0)
key, netkey = random.split(key)

iters = 500_000
data_keys = random.split(key, iters)
data_keys = tuple(data_keys)  # NOTE NOTE NOTE
for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")

mattjj, Sep 05 '23

Ah I think @hawkinsp figured it out. This also doesn't leak:

from jax import random
import psutil

key = random.PRNGKey(0)
key, netkey = random.split(key)

iters = 500_000
data_keys = random.split(key, iters)
data_keys = list(data_keys)  # NOTE list not tuple
for i in range(iters):
    inputs = random.normal(data_keys.pop(), (10,))  # NOTE pop
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")

The issues are:

  1. the program uses more memory than you might expect because key, *data_keys = random.split(key, iters + 1) unpacks a single array into 500,001 separate arrays (i.e. key plus the 500,000 elements of the data_keys list);
  2. the memory usage grows because each jax.Array has cached properties on it, so each of the 500,000 arrays has that cached metadata populated when it's touched (i.e. when it's read by the random.normal(data_keys[i], (10,)) expression).
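Issue 1 can be reproduced at a smaller scale (1,000 keys instead of 500,000, to keep it quick): star-unpacking iterates over the batched key array and collects a Python list in which every element is its own `jax.Array`, whereas slicing keeps everything in one array object.

```python
import jax
from jax import random

n = 1_000  # smaller than the issue's 500_000; the structure is the same

# Star-unpacking: `key` plus a Python *list* holding n separate jax.Arrays.
key, *data_keys = random.split(random.PRNGKey(0), n + 1)
print(type(data_keys).__name__, len(data_keys))  # list 1000
print(all(isinstance(k, jax.Array) for k in data_keys), data_keys[0].shape)  # True (2,)

# Slicing instead: one array object holds all n keys.
keys = random.split(random.PRNGKey(0), n + 1)
key2, data_keys_array = keys[0], keys[1:]
print(isinstance(data_keys_array, jax.Array), data_keys_array.shape)  # True (1000, 2)
```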

I'm not sure if there's anything to fix here, at least in JAX itself, though I could be wrong. If you have a real program with this kind of memory issue, consider not unpacking 500,000 separate arrays and instead writing something more like

from jax import random
import psutil

key = random.PRNGKey(0)
key, netkey = random.split(key)

iters = 500_000
keys = random.split(key, iters + 1)
key = keys[0]
data_keys = keys[1:]  # note: keep data_keys as one array; don't unpack it into a list of arrays
for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")

mattjj, Sep 05 '23

Actually @hawkinsp and @yashk2810 have ideas for how we can reduce the cached stuff here... looking into it!

mattjj, Sep 05 '23

@mattjj thanks for the tip, I was unaware of the array caching mechanics (hidden side effect!). I figured precomputing all the keys in one batched operation would be more efficient than doing it one at a time in the loop. The array slice seems like a good alternative.
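If the loop only needs the samples, the batching idea can be pushed one step further. This sketch goes beyond what the thread discusses and assumes every iteration draws an independent `(10,)` normal vector and that all samples can be held at once (500,000 × 10 float32 values is about 20 MB): `jax.vmap` maps `random.normal` over the whole key array, so no per-key Python array objects are created.

```python
import jax
import jax.numpy as jnp
from jax import random

iters = 500_000
keys = random.split(random.PRNGKey(0), iters + 1)
key, data_keys = keys[0], keys[1:]  # a single (iters, 2) key array

# vmap maps the per-key sampler over the leading axis of data_keys,
# so every draw happens in one call instead of one Python iteration each.
sample = jax.vmap(lambda k: random.normal(k, (10,)))
inputs = sample(data_keys)
print(inputs.shape)  # (500000, 10)

# Row i should match what the loop would have drawn with data_keys[i].
print(jnp.allclose(inputs[0], random.normal(data_keys[0], (10,))))
```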

smorad, Sep 05 '23

https://github.com/google/jax/pull/17452 should fix the leak problem.

yashk2810, Sep 05 '23

The PR has been submitted. Can you try again?

yashk2810, Sep 06 '23

Sorry, I'm not super familiar with your PR process. https://github.com/google/jax/pull/17452 does not have any commits attached to it. I just cloned and tested on master and I'm still seeing the memory leak, even with the call to gc.collect() (albeit it leaks much more slowly).

smorad, Sep 06 '23

The script I ran, for posterity:

from jax import random
import psutil
import gc

key = random.PRNGKey(0)

iters = 500_000
key, *data_keys = random.split(key, iters + 1)
for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")
        gc.collect()

Output:
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.8 MB
909.9 MB
909.9 MB
909.9 MB
909.9 MB
909.9 MB
909.9 MB
909.9 MB
910.0 MB

smorad, Sep 06 '23

I can't seem to repro it. Are you sure you installed jax at HEAD? Note that you would need to clone jax and then run cd jax; pip install -U . You might also want to uninstall jax before doing this.
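When testing a fix from HEAD, it can help to confirm which versions are actually installed and imported. Installing jax from a source checkout does not rebuild jaxlib, so the two versions can differ (as in the install log further down). A minimal check using only the standard library's package metadata:

```python
from importlib.metadata import version

import jax

# Versions pip recorded for each installed distribution.
print("jax (pip):   ", version("jax"))
print("jaxlib (pip):", version("jaxlib"))

# Version of the jax package this interpreter actually imports.
print("jax (import):", jax.__version__)
```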

Also note that the memory usage will go up and down, which is expected.

import gc
import psutil
from jax import random

key = random.PRNGKey(0)
iters = 500_000
key, *data_keys = random.split(key, iters + 1)
for i in range(iters):
    inputs = random.normal(data_keys[i], (10,))
    if i % 1000 == 0:
        print(f"{psutil.Process().memory_info().rss / 1024 ** 2:.1f} MB")
        gc.collect()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
1011.9 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB
1012.5 MB

yashk2810, Sep 06 '23

Weird, maybe it's something to do with my python/macos version.

pip uninstall jax jaxlib
pip install git+https://github.com/google/jax jaxlib # Successfully installed jax-0.4.16.dev20230906 jaxlib-0.4.14
python debug.py # The listed script
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.4 MB
907.5 MB
907.5 MB
907.5 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.6 MB
907.7 MB
907.7 MB
907.7 MB
907.7 MB
907.7 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.8 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
907.9 MB
908.0 MB
908.0 MB
908.0 MB
908.1 MB
908.1 MB
908.1 MB
908.2 MB
908.2 MB
908.2 MB

smorad, Sep 06 '23

Any updates on this issue, @yashk2810? I'm on jax==0.4.23 and jaxlib==0.4.23+cuda11.cudnn86 (which should include the changes from #17452) and see the same memory leak issue.

karen-sy, Feb 19 '24

I think this is working as expected: when you run these lines:

iters = 500_000
key, *data_keys = random.split(key, iters + 1)

you're creating a list data_keys with 500,000 elements, each of which is a JAX array (and each of which will have the memory overhead associated with a concrete JAX array instance). I would not expect the memory footprint to be reduced until those 500,000 arrays are deallocated, for example by deleting data_keys and then calling gc.collect().

Iterating through data_keys does not change or deallocate the contents of data_keys.

Perhaps you were aiming for something like this:

keys = random.split(key, iters + 1)
key, data_keys = keys[0], keys[1:]

in which data_keys will be a single array rather than a list of arrays.
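If a list of per-key arrays already exists, a sketch of one way back to a single array (assuming nothing else still holds references to the list's elements) is to stack them and then delete the list, following the del plus gc.collect() suggestion above. The small iters value is only for illustration.

```python
import gc

import jax.numpy as jnp
from jax import random

iters = 1_000  # small for illustration
key, *data_keys = random.split(random.PRNGKey(0), iters + 1)  # list of arrays

# Collapse the list back into a single (iters, 2) array ...
data_keys_array = jnp.stack(data_keys)
# ... then drop the list so its per-key arrays become unreachable and can be freed.
del data_keys
gc.collect()

print(data_keys_array.shape, data_keys_array.dtype)  # (1000, 2) uint32
```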

jakevdp, Feb 27 '24

(I just realized my comment is the same solution @mattjj offered above... sorry if there's something I'm missing)

jakevdp, Feb 27 '24