
Allow preallocation of GPU memory based on total memory

gehring opened this issue Sep 16 '20 • 5 comments

Preallocating memory as a fraction of currently available memory easily leads to a race condition when launching several JAX processes that share a GPU. It would be useful to be able to set how much memory to preallocate in a way that is invariant to the order in which processes are launched. Being able to specify either an absolute amount of memory or a fraction of the total GPU memory would solve this with minimal changes to JAX/XLA.
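For concreteness, here is a sketch of the kind of launch that triggers the race; train.py is just a placeholder for any JAX program, and 0.3 is an arbitrary fraction:

```python
# Sketch only: start several JAX processes concurrently on one shared GPU.
# Each one preallocates a fraction of whatever memory is free when its XLA
# client initializes, so what it gets depends on startup order and timing.
import os
import subprocess

env = {**os.environ, "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.3"}
procs = [subprocess.Popen(["python", "train.py"], env=env) for _ in range(3)]
for p in procs:
    p.wait()
```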

gehring · Sep 16 '20 19:09

Are you saying that the current "preallocation fraction" environment variable actually means "preallocate this fraction of the memory still available after previous processes have claimed their own"?

CC @skye, who knows way more about this than me.

jekbradbury · Sep 17 '20 06:09

Yes, as far as I can tell, it preallocates based on a fraction of the memory that hasn't already been allocated to an existing process. In my experience, launching, say, six processes with a 0.15 fraction results in an uneven allocation of GPU memory, i.e., not every process is given the same amount of GPU memory. This behavior is consistent with how I interpret the docs, which use the words "currently-available GPU memory".

gehring · Sep 17 '20 20:09

@gehring your understanding is correct! It shouldn't be too bad to add a separate option for allocating a percentage of total memory, although it will require a C++ change to TF (around here). If someone is willing to dive in, the actual change should be straightforward, and I can help shepherd any JAX and/or TF PRs. Otherwise I can try to get to this myself.

Alternatively, it looks like it'll already do what you want if you enable unified memory via the env var TF_FORCE_UNIFIED_MEMORY=1. I'm not very familiar with unified memory, but it may be worth a shot as a quick workaround.
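Something like this, for anyone who wants to try it (untested sketch; the variables have to be set before the GPU backend initializes, so set them before touching any jax API):

```python
# Untested sketch of the workaround above: set the allocator env vars before
# jax creates its GPU backend, otherwise they are ignored.
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.3"  # fraction of *remaining* memory
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"           # unified-memory workaround

import jax

print(jax.devices())  # backend (and allocator) is created on first use
```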

skye · Sep 17 '20 21:09

Allocating either a fraction of the total memory or a fixed number of GB would be very useful. I found this issue after a similar experience to @gehring's.

I have a GPU with 32GB of memory, and I hoped to run 3 processes each with ~10GB of memory. When I run all of them with XLA_PYTHON_CLIENT_MEM_FRACTION=0.3, I can see that each process allocates 30% of the remaining memory:

proc0 - gets 30% of the 100% remaining avail memory = 30% of total memory = 9.6GB
proc1 - gets 30% of the 70% remaining avail memory = 21% of total memory = 6.7GB
proc2 - gets 30% of the 49% remaining avail memory = 15% of total memory = 4.7GB
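
For reference, the compounding is easy to reproduce in a few lines (a sketch of the arithmetic, not of the actual allocator):

```python
# Reproduce the numbers above: each process preallocates a fraction of the
# GPU memory still free at the moment it starts.
total_gb = 32.0
fraction = 0.3
free = total_gb
for i in range(3):
    alloc = fraction * free
    print(f"proc{i} gets {alloc:.1f}GB ({100 * alloc / total_gb:.0f}% of total)")
    free -= alloc
# proc0 gets 9.6GB (30% of total)
# proc1 gets 6.7GB (21% of total)
# proc2 gets 4.7GB (15% of total)
```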

eamartin · Aug 12 '21 21:08

Hi @jekbradbury @skye, any updates on this?

llan-ml · May 29 '22 04:05

Closed by https://github.com/tensorflow/tensorflow/pull/58638

nouiz · Nov 28 '22 19:11