Clear GPU memory

Open • clemisch opened this issue 4 years ago • 21 comments

Dear jax team,

I'd like to use jax alongside other tools that run on the GPU in the same pipeline. Is there a way to "encapsulate" the use of jax/XLA so that the GPU is freed afterwards? I'd be fine with having to copy the DeviceArrays into NumPy manually.

Maybe something like:

with jax.Block():  # hypothetical API: jax.Block does not exist
    result = some_jitted_fun(a, b, c)
    result = onp.copy(result)  # onp = original NumPy; copy the result to host

I imagine that designing how objects and their GPU memory are handled is not straightforward, and may even be practically impossible. Could I at least tell jax to use GPU memory incrementally instead of filling it completely on import?
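Without a new API, the closest thing available is probably to keep device results local to a function and hand back only a host copy, so nothing outside the "block" still references device memory. A minimal sketch; some_jitted_fun is a hypothetical stand-in for the real computation and run_block is a hypothetical helper name. With JAX's default allocator, the released buffer goes back to JAX's own memory pool rather than to the GPU.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def some_jitted_fun(a, b, c):
    # Hypothetical stand-in for the real computation.
    return a * b + c


def run_block(fn, *args):
    """Run fn on the device and return an independent host (NumPy) copy.

    The device result is only referenced inside this function, so once it
    returns nothing keeps that device buffer alive and JAX can release it.
    """
    out = fn(*args)
    host = np.array(out)  # np.array always copies into a new host array
    del out  # drop the last device reference explicitly
    return host


a = b = c = jnp.ones(3)
result = run_block(some_jitted_fun, a, b, c)
print(type(result), result)
```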

clemisch, Aug 21 '19

Nice idea! We've had a few related requests recently, and I think we can provide better tools here. (JAX is actually pretty small, and the way it handles GPU memory, and backend memory in general, is pretty straightforward, so we should have the right tools at our disposal!)

As for grabbing all of the memory on import, though, have you taken a look at the GPU memory allocation note in the docs? You can prevent JAX from allocating everything up front, or control the fraction of GPU memory it allocates up front. Could that help?
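For reference, the settings from that note are environment variables that JAX reads when it first initializes the GPU backend, so they must be set before that happens; setting them before importing jax is the safest option. A minimal sketch (the 50% fraction is just an example value):

```python
import os

# JAX reads these when it first initializes the GPU backend, so set them
# before that happens (before the first `import jax` is the safest place).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# Or keep preallocation but cap it, e.g. at 50% of GPU memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax.numpy as jnp

x = jnp.arange(10)
total = int(x.sum())
print(total)
```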

mattjj, Aug 21 '19

That's great to hear, thank you! A programmatic solution sometime in the future would be very cool, but I think XLA_PYTHON_CLIENT_PREALLOCATE=false could do the trick for now. But I assume it only affects preallocation, not freeing the memory afterwards?

I use jax in designated "blocks" in the pipeline, so freeing and re-allocating memory should not be that bad for performance in my use case (if it is possible).

Edit: In my use case, the memory is not freed with XLA_PYTHON_CLIENT_PREALLOCATE=false. This leaves the GPU usable by other tools in principle, but it's not a great solution TBH. So the programmatic solution would be very cool! :wink:

clemisch, Aug 21 '19

But I assume only affects pre-allocation, not freeing the memory afterwards?

Device memory for an array ought to be freed once all Python references to it drop, i.e. upon destruction of any corresponding DeviceArray. You could encourage this explicitly with del my_device_array, if Python scope isn't already lined up with your pipeline "blocks."

In your example, the line

result = onp.copy(result)

will drop the only reference to a DeviceArray (from the previous line), and should clear the device memory associated with the value of some_jitted_fun(a, b, c), for the same reason.
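A small sketch of those reference semantics, written against the current jax.Array API (DeviceArray has since been renamed), which also has delete() and is_deleted() methods for freeing a buffer eagerly when Python scope doesn't line up with the pipeline blocks:

```python
import jax.numpy as jnp

x = jnp.ones((1024, 1024))
y = x  # a second Python reference to the same device buffer

del x  # only removes the name `x`; `y` still keeps the buffer alive
assert not y.is_deleted()

# Free the buffer eagerly; any later use of `y` raises an error.
y.delete()
assert y.is_deleted()
```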

froystig, Aug 21 '19

Thank you @froystig, that sounds like a great pythonic solution! However, I don't see that behavior. In the snippet below, GPU memory is not freed after del arr. Am I missing something?

%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

import jax.numpy as np

arr = np.arange(int(1e9))
arr += 1

del arr

clemisch, Aug 21 '19

XLA_PYTHON_CLIENT_PREALLOCATE=false only affects preallocation, so, as you've observed, memory will never be released by the allocator (although it will be available for other DeviceArrays in the same process).

You could try setting XLA_PYTHON_CLIENT_ALLOCATOR=platform instead. This will be slower, but will also actually deallocate when a DeviceArray's buffer is released. I forgot to mention this in the GPU memory allocation note, my bad! I'll update the note to include this.
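A minimal sketch of that configuration, adapting the snippet above with a smaller array (1e6 instead of 1e9 elements, an arbitrary choice so it fits on any device). As with the preallocation flag, the variable has to be set before JAX initializes the GPU backend:

```python
import os

# Must be set before JAX creates its GPU client; setting it before the
# first `import jax` is the safest option.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax.numpy as jnp

arr = jnp.arange(int(1e6))
arr = arr + 1  # the old buffer loses its last reference here and is released
last = int(arr[-1])
del arr  # with "platform", the memory is actually deallocated, not pooled
print(last)
```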

skye, Aug 21 '19

Thank you @skye, that solved it!

clemisch, Aug 21 '19

I'm not against closing this, but since the runtime does increase quite a bit (17s vs 13s in my case, 30% slower), I think there is at least some demand for having blockwise preallocation with clearing the memory afterwards (similar to my crude snippet in the beginning).

That is, as opposed to instantly clearing memory every time, for example, a function returns. For some nested frameworks this makes a noticeable performance difference.

clemisch, Aug 21 '19

Would it work to have a (slow) function call that tries to free unused memory? I say "try" because the default allocator allocates large regions, with multiple DeviceArrays possibly occupying a single region, so freeing one DeviceArray may not allow us to free the whole region. We could also have an even slower function that copies DeviceArrays around to free as much memory as possible.
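To illustrate why freeing one array need not free any device memory under a region-based allocator, here is a toy model. It is a hypothetical sketch, not XLA's code (the real default is XLA's BFC allocator, which also reuses freed space within a region): regions are carved into blocks by bumping an offset, and a region can only be handed back once every block in it is free.

```python
class ToyRegionAllocator:
    """Hypothetical toy model of a region-based allocator (not XLA's code).

    Regions of `region_size` bytes are requested from the "device" and
    carved into blocks by bumping an offset. Freed space is not reused, and
    a region can only be returned to the device once all its blocks are free.
    """

    def __init__(self, region_size):
        self.region_size = region_size
        self.regions = {}  # region id -> {"used": bytes handed out, "live": live block ids}
        self._next_region = 0
        self._next_block = 0

    def alloc(self, nbytes):
        if nbytes > self.region_size:
            raise ValueError("request larger than a region")
        # First fit: use the first region with enough untouched space left.
        rid = next(
            (k for k, r in self.regions.items() if self.region_size - r["used"] >= nbytes),
            None,
        )
        if rid is None:  # no room anywhere: grab a new region from the device
            rid = self._next_region
            self._next_region += 1
            self.regions[rid] = {"used": 0, "live": set()}
        block = (rid, self._next_block)
        self._next_block += 1
        self.regions[rid]["used"] += nbytes
        self.regions[rid]["live"].add(block)
        return block

    def free(self, block):
        rid, _ = block
        self.regions[rid]["live"].discard(block)

    def try_release(self):
        """Return fully empty regions to the device; report how many were released."""
        empty = [rid for rid, r in self.regions.items() if not r["live"]]
        for rid in empty:
            del self.regions[rid]
        return len(empty)


alloc = ToyRegionAllocator(region_size=100)
x = alloc.alloc(60)  # region 0
y = alloc.alloc(30)  # also region 0 (40 bytes were still free)
z = alloc.alloc(50)  # region 1
alloc.free(x)
print(alloc.try_release())  # 0: region 0 still holds y, region 1 holds z
alloc.free(z)
print(alloc.try_release())  # 1: region 1 is now empty
```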

I'm also not sure what you mean exactly by blockwise preallocation, can you explain the API you have in mind?

skye, Aug 21 '19

Thank you for your feedback! If "slow" means on the order of 0.1 s to 1 s, that would be great.

Although I would not call my crude idea an API, I was thinking of using that with-block to get behavior like XLA_PYTHON_CLIENT_PREALLOCATE=false with XLA_PYTHON_CLIENT_ALLOCATOR=default inside the block. After the block, the memory could be freed, as with XLA_PYTHON_CLIENT_ALLOCATOR=platform. Maybe this is similar to your first suggestion.

With my very limited understanding of XLA memory handling, this would use GPU memory incrementally without having to free every little piece after each small inner function call (fragmentation?), resulting in better performance. When the work is done and I have a handful of result arrays, I would copy them to the host and free the GPU. Please feel free to tell me if this does not make sense!
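Part of this can be approximated on top of the public jax.live_arrays() function: delete every array created inside the block on exit. A sketch only; device_block is a hypothetical helper name, not a JAX API. It frees buffers back to the allocator, so whether the GPU itself sees the memory returned still depends on XLA_PYTHON_CLIENT_ALLOCATOR. It also deletes arrays created during the block by other threads or by anything else that intends to reuse them, so copy whatever you need to the host before leaving.

```python
import contextlib

import jax
import jax.numpy as jnp


@contextlib.contextmanager
def device_block():
    """Hypothetical helper: delete every jax.Array created inside the block on exit."""
    keep = jax.live_arrays()  # strong refs, so these ids can't be reused inside the block
    keep_ids = {id(arr) for arr in keep}
    try:
        yield
    finally:
        for arr in jax.live_arrays():
            if id(arr) not in keep_ids and not arr.is_deleted():
                arr.delete()
        del keep


with device_block():
    a = jnp.ones((4, 4))
    b = (a @ a).sum()  # 4 * 4 * 4 = 64
    result = float(b)  # host copy survives the block

print(result, a.is_deleted())
```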

clemisch, Aug 21 '19

NB: the code below is probably completely obsolete now. (levskaya, 2021.02.18)

I was asked to post a utility function I use to delete DeviceArrays in colab:

import gc
import jax
def reset_device_memory(delete_objs=True):
  """Free all tracked DeviceArray memory and delete objects.
  Args:
    delete_objs: bool: whether to delete all live DeviceValues or just free.
  Returns:
    number of DeviceArrays that were manually freed.
  """
  dvals = (x for x in gc.get_objects() if isinstance(x, jax.xla.DeviceValue))
  n_deleted = 0
  for dv in dvals:
    if not isinstance(dv, jax.xla.DeviceConstant):
      try:
        dv._check_if_deleted()  # pylint: disable=protected-access
        dv.delete()
        n_deleted += 1
      except ValueError:
        pass
    if delete_objs:
      del dv
  del dvals
  gc.collect()
  return n_deleted

this and some memory reporting utils in a gist: https://gist.github.com/levskaya/37f72b76bd5c72f9e5ce48ce154a9246 and in a public colab: https://colab.research.google.com/drive/1odOdMbbp-47WyDhjIfTDWukOBTSUt5Q6
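In current JAX versions, where DeviceValue no longer exists, a rough counterpart of the memory-reporting part is to sum nbytes over jax.live_arrays(). This is a sketch: it counts the logical size of live arrays, not what the allocator has actually reserved (on GPU, jax.Device.memory_stats() reports allocator-level numbers where available), and for arrays sharded or replicated across devices nbytes is the global size, so treat it as a single-device approximation. live_array_bytes is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def live_array_bytes(platform=None):
    """Total logical size, in bytes, of all live jax.Arrays (optionally on one platform)."""
    return sum(arr.nbytes for arr in jax.live_arrays(platform))


x = jnp.zeros((1000,), dtype=jnp.float32)  # 4000 bytes
print(live_array_bytes())
```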

levskaya, Mar 11 '20

FYI: the memory counting in the linked Gist is no longer completely accurate because, e.g., jax.numpy.zeros does not actually allocate memory but is still counted (you can make the reported memory usage arbitrarily large).

jonasrauber, May 02 '20

@levskaya The object_memory_usage() function in your script raises the following error:

AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceValue'

Any idea how to resolve this, or what DeviceValue has been replaced with in newer versions of JAX?

Jeevesh8, Feb 19 '21

@Jeevesh8 - there have been massive changes in the underlying JAX infrastructure since the above was written (including moving more of the Python logic into C++). I believe some new memory utilities are being written to work with the newer system.

levskaya, Feb 19 '21

@levskaya Can you point me somewhere I can find these? I would be grateful.

Jeevesh8, Feb 19 '21

If you'd like to delete all on-device buffers, you can now reach directly into the underlying C++ internals like this:

import jax
backend = jax.lib.xla_bridge.get_backend()  # internal API as of early 2021
for buf in backend.live_buffers():
    buf.delete()

You'll get errors about using a deleted buffer if you try to use any existing references to these buffers. Note that these APIs are still subject to change as well!
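In recent JAX versions, the public jax.live_arrays() function offers a similar view without reaching into the backend client. A sketch of the same clean-up on top of it; delete_all_device_arrays is a hypothetical helper name. The caveat above still applies: using any remaining reference to a deleted array raises an error.

```python
import gc

import jax
import jax.numpy as jnp


def delete_all_device_arrays(platform=None):
    """Delete every live jax.Array; returns how many were deleted."""
    gc.collect()  # first let Python reclaim arrays that are only kept alive by cycles
    n_deleted = 0
    for arr in jax.live_arrays(platform):
        if not arr.is_deleted():
            arr.delete()
            n_deleted += 1
    return n_deleted


x = jnp.ones((256, 256))
n = delete_all_device_arrays()
print(n, x.is_deleted())
```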

skye, Feb 19 '21

@skye, in JAX I noticed a very peculiar thing: if I assign a jax.numpy.ndarray to a leaf of a nested dictionary, the array persists in GPU memory even after no references to the outermost dictionary remain.

This array was created outside of a pure function. Is this intended behavior? Why?
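One way to tell whether such an array is really still alive, rather than its memory merely staying reserved by the default allocator as described earlier in the thread, is to look for it in jax.live_arrays(). A diagnostic sketch; is_live is a hypothetical helper name. In a plain script, deleting the outermost dictionary drops the only reference, so the leaf disappears from the list. If it doesn't in your setup, something else still references it; in IPython/Jupyter, for example, the output history (Out, _) can keep displayed values alive.

```python
import gc

import jax
import jax.numpy as jnp


def is_live(array_id):
    """True if a jax.Array with this id() is still alive."""
    return any(id(arr) == array_id for arr in jax.live_arrays())


d = {"outer": {"inner": {"leaf": jnp.ones((1024,))}}}
leaf_id = id(d["outer"]["inner"]["leaf"])
print(is_live(leaf_id))  # True: reachable through d

del d  # drops the only reference to the nested dicts and the leaf
gc.collect()
print(is_live(leaf_id))  # False in a plain script
```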

You seem to be very knowledgeable about JAX. Thank you for all your comments on other issues, and for replying to this one.

Yours gratefully,

Jeevesh8, Feb 19 '21

I'm trying to write some JAX tests with the pytest framework. My tests run sequentially, and I have both TensorFlow and JAX tests. After the JAX tests run, the GPU memory is not cleared, which can cause problems when the TensorFlow tests run. I tried setting the environment variables in a setup function so that no memory gets preallocated (XLA_PYTHON_CLIENT_ALLOCATOR=platform), but this doesn't work in my case. The reason is that some of my functions/classes have default JAX values that are evaluated before the setup functions are called. See the example below:

import os
import time

import jax.numpy as jnp


def setup_module(module):
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


def teardown_module(module):
    del os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]


def f(x: jnp.ndarray, c: jnp.ndarray = jnp.ones([1])) -> jnp.ndarray:
    return x + c


def test_f():
    x = jnp.array([1, 2])
    f(x, jnp.ones([1]))
    # To check GPU memory usage
    time.sleep(5.0)

The default argument c: jnp.ndarray = jnp.ones([1]) is evaluated when the module is imported, so it already triggers preallocation before setup_module runs. Is there any way to clear the memory or circumvent this issue? Furthermore, it would be great to be able to run some integration tests with preallocated memory for performance reasons.
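Python evaluates default argument values once, when the def statement runs at import time. A common workaround, sketched here on the f from the example, is a None sentinel so that the default array is only created on the first call:

```python
from typing import Optional

import jax.numpy as jnp


def f(x: jnp.ndarray, c: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    # Build the default on first call instead of at import time, so merely
    # importing this module no longer touches the GPU.
    if c is None:
        c = jnp.ones([1])
    return x + c


out = f(jnp.array([1, 2]))
print(out)
```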

tetterl, Apr 06 '21

Have you tried also setting preallocate to false?

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # add this
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

clemisch, Apr 06 '21

@clemisch Yes, I tried this too. If I remove the default value in f, it works.
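An alternative that keeps the default argument is to set the variables in a conftest.py, which pytest imports before it collects (and therefore imports) the test modules in that directory. A sketch, assuming nothing imported earlier (e.g. a pytest plugin) has already initialized JAX's GPU backend; using setdefault lets values exported in the shell take precedence:

```python
# conftest.py, placed in the test directory (or a parent of it).
import os

os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")
```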

tetterl, Apr 06 '21

Any update on this?

jaanli, Feb 16 '22

To prevent the GPU memory preallocation triggered by jnp.ones([1]), I had to escape the quotes in the environment variable's value to make it work: os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = '\"platform\"'. Alternatively, in a notebook you can set it with: %env XLA_PYTHON_CLIENT_ALLOCATOR="platform"

kostyaev, Jul 14 '22

Somewhat related to #18181 and #5179, which are about making JAX more cooperative with other CUDA libraries.

pwuertz, Oct 20 '23