Trying to get a PRNGKey on cpu results in strange gpu memory allocation

Open ysngshn opened this issue 4 years ago • 9 comments

I am trying to get a jax.random.PRNGKey on cpu to manage the source of randomness for some host function. If I understand correctly, one should do something like:

import jax
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)

This indeed returns a key on cpu, and subsequent operations like jax.random.normal also stay on cpu; so far so good!

However, I just realized that executing the line key = jax.jit(jax.random.PRNGKey, backend='cpu')(0) in a python interpreter results in an allocation of ~140 MB of GPU memory, which is unexpected. This is not the default 90% memory preallocation, which would occupy > 10 GB of GPU memory.

Any ideas on why this happens? Is this somehow expected behavior, or is there maybe a better way to achieve what I need? (I also need to run functions on the gpu, so setting JAX to use cpu globally does not solve the issue.)
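For reference, this is roughly the kind of mixed setup I have in mind (a rough sketch; the function bodies are just placeholders):

import jax
import jax.numpy as jnp

# Randomness for host-side work stays on the cpu backend.
cpu_key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)
cpu_sample = jax.jit(lambda k: jax.random.normal(k, (4,)), backend='cpu')(cpu_key)

# The heavy computation still runs on the default (gpu) backend.
gpu_fn = jax.jit(lambda x: jnp.sum(x ** 2))
result = gpu_fn(jnp.ones(1000))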

ysngshn avatar Aug 13 '21 16:08 ysngshn

Hrm, I don't know why this would happen, though I would guess it's an effect of backend initialization. That is, when you run the first JAX operation (not when JAX is imported), all available backends are initialized. That means that if you haven't already initialized the backends, the first operation on CPU will cause the GPU backend (if available) to be initialized.
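One rough way to observe this, assuming an NVIDIA GPU and nvidia-smi on the PATH (just a sketch):

import subprocess

def gpu_mem_used():
    # Ask nvidia-smi for the current GPU memory usage.
    out = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader'],
        capture_output=True, text=True)
    return out.stdout.strip()

print('before first JAX op:', gpu_mem_used())
import jax
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)  # first op: all available backends get initialized
print('after first JAX op: ', gpu_mem_used())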

Is this 140MB allocation a problem somehow, or are you just curious about it?

mattjj avatar Aug 13 '21 16:08 mattjj

Thanks for the quick reply! This is mainly out of curiosity. Backend initialization indeed seems a plausible explanation. Ideally, I guess one wouldn't expect the gpu backend to be initialized when working only with the cpu backend. Also, since the default 90% preallocation is not triggered, couldn't this potentially cause some issues?

ysngshn avatar Aug 13 '21 16:08 ysngshn

Ideally, I guess one wouldn't expect the gpu backend to be initialized when working only with the cpu backend.

We might be able to make this initialization lazier... @skye do you know if there are any reasons not to do that?

Also, since the default 90% preallocation is not triggered, couldn't this potentially cause some issues?

Maybe, but I don't know of any issues it has caused yet! Any guesses?

mattjj avatar Aug 13 '21 16:08 mattjj

Maybe, but I don't know of any issues it has caused yet! Any guesses?

I do remember having to disable gpu preallocation to avoid a cuda oom error (a workaround that is also mentioned in the JAX documentation), but that was for some other code (which also uses jax.random for a dataloader function on cpu) and might be totally unrelated.
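For reference, the knobs I used look roughly like this (a sketch; they have to be set before the first JAX operation, e.g. before importing jax):

import os

# Disable the default ~90% preallocation and allocate GPU memory on demand instead.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Alternatively, keep preallocation but shrink the fraction of GPU memory it grabs.
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'

import jax
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)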

Anyway I feel like this is something interesting to point out. Thanks a lot for the quick replies!

ysngshn avatar Aug 13 '21 17:08 ysngshn

I'm guessing you somehow have GPU preallocation disabled. You can check by setting the env var TF_CPP_MIN_LOG_LEVEL=0, which turns on extra JAX logging (and TensorFlow logging; we use the same logging library). After running an operation with jax, it should print:

XLA backend will use up to XXX bytes on device 0 for BFCAllocator.

if preallocation is disabled, or:

XLA backend allocating XXX bytes on device 0 for BFCAllocator.

if preallocation is enabled.

That said, I'm not sure what the 140MB allocation is about. I'm seeing an extra ~105MB allocated too. I turned on a bunch of GPU memory logging with TF_CPP_VMODULE=bfc_allocator=1,stream_executor_pimpl=1, but the output indicates it's only allocating an initial 2MB (log output in this gist: https://gist.github.com/skye/931c8371b858cab96bae251652bfdf11). Maybe the CUDA runtime allocates some space for itself?
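Roughly how I wire up that logging from Python, in case it helps (a sketch; the variables need to be set before jax/jaxlib loads the backend):

import os

# Extra XLA/TF logging; set before importing jax so the backend picks it up.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'bfc_allocator=1,stream_executor_pimpl=1'

import jax
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)
# Look for "XLA backend will use up to ..." (preallocation off) or
# "XLA backend allocating ..." (preallocation on) in the log output.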

Ideally, I guess one wouldn't expect the gpu backend to be initialized when working only with the cpu backend.

Agreed. I'm gonna look into not initializing all backends when a specific backend is requested.

skye avatar Aug 13 '21 18:08 skye

I tried the same code in a python interpreter launched with TF_CPP_MIN_LOG_LEVEL=0 python, and I am getting a line saying

XLA backend allocating 10416055910 bytes on device 0 for BFCAllocator.

while nvidia-smi tells me that only 140 MB is allocated.

ysngshn avatar Aug 13 '21 20:08 ysngshn

That's very strange. Can you provide the full jax + nvidia-smi output?

skye avatar Aug 13 '21 21:08 skye

That's very strange. Can you provide the full jax + nvidia-smi output?

I have put the jax + nvidia-smi output in this gist.

ysngshn avatar Aug 14 '21 12:08 ysngshn

FWIW I've spotted the same thing recently. I'm still not sure what is allocating, but I don't think this allocation is coming via the CUDA runtime or driver APIs (I've tried hooking those and printing when allocations happen).

For a minimal repro:

!/usr/lib/libcuda/nvidia-smi  # IPython shell escape: GPU memory usage before the backend is touched
import jax
assert jax.default_backend() == 'gpu'  # first JAX call; backends get initialized here
!/usr/lib/libcuda/nvidia-smi  # GPU memory usage afterwards shows the extra allocation

Maybe the CUDA runtime allocates some space for itself?

I suspect so...

tomhennigan avatar Aug 14 '21 18:08 tomhennigan