Trying to get a PRNGKey on cpu results in strange gpu memory allocation
I am trying to get a jax.random.PRNGKey on cpu to manage the source of randomness for some host function. If I understand correctly, one should do something like:
import jax
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)
This indeed returns a key on CPU, and subsequent operations like jax.random.normal also stay on CPU, so far so good!
However, I just realized that executing the line key = jax.jit(jax.random.PRNGKey, backend='cpu')(0) in a Python interpreter results in an allocation of ~140 MB of memory on the GPU, which is unexpected. This is not the 90% memory pre-allocation, which would occupy > 10 GB of GPU memory.
Any ideas on why this happens? Is this somehow expected behavior, or is there a better way to achieve what I need? (I also need to run functions on the GPU, so setting JAX to use the CPU globally does not solve the issue.)
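For reference, a minimal sketch of what I mean by a subsequent operation (the shape here is arbitrary, and I pin it to CPU explicitly the same way):
noise = jax.jit(lambda k: jax.random.normal(k, (3,)), backend='cpu')(key)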
Hrm, I don't know why this would happen, though I would guess it's an effect of backend-initialization. That is, when you run the first JAX operation (not when JAX is imported), all available backends are initialized. That means that if you haven't already initialized the backends, the first operation on CPU will cause the GPU backend (if available) to be initialized.
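As a rough sketch of what I mean (assuming nvidia-smi is on your PATH and a GPU is available), you can watch the backend come up around the first operation:
import subprocess
import jax
subprocess.run(['nvidia-smi'])  # importing jax alone does not initialize any backend
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)  # first op initializes all available backends
subprocess.run(['nvidia-smi'])  # GPU memory usage should change once the GPU backend is up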
Is this 140MB allocation a problem somehow, or are you just curious about it?
Thanks for the quick reply! This is mainly out of curiosity. Backend initialization indeed seems a plausible explanation. Ideally I guess one wouldn't expect the GPU backend to be initialized when working only on the CPU backend. Also, since the default 90% allocation is not triggered, wouldn't this potentially cause some issues?
Ideally I guess one wouldn't expect the GPU backend to be initialized when working only on the CPU backend.
We might be able to make this initialization lazier... @skye do you know if there are any reasons not to do that?
Also, since the default 90% allocation is not triggered, wouldn't this potentially cause some issues?
Maybe, but I don't know of any it's caused yet! Any guesses?
Maybe, but I don't know of any it's caused yet! Any guesses?
I do remember having to disable GPU preallocation to avoid a CUDA OOM error, which is also mentioned in the JAX documentation, but that was for some other code (which also uses jax.random in a dataloader function on CPU) and might be totally unrelated.
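What I did back then was roughly the following (per the GPU memory allocation docs; exporting the variable in the shell before starting Python works the same way):
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # must be set before JAX initializes its backends
import jax
key = jax.jit(jax.random.PRNGKey, backend='cpu')(0)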
Anyway I feel like this is something interesting to point out. Thanks a lot for the quick replies!
I'm guessing you somehow have GPU preallocation disabled. You can check by setting the env var TF_CPP_MIN_LOG_LEVEL=0, which turns on extra JAX logging (and TensorFlow logging; we use the same logging library). After running an operation with JAX, it should print:
XLA backend will use up to XXX bytes on device 0 for BFCAllocator.
if preallocation is disabled, or:
XLA backend allocating XXX bytes on device 0 for BFCAllocator.
if preallocation is enabled.
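If it's more convenient, roughly the same check from inside Python (the variable is read early, so it's safest to set it before importing jax):
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'  # turn on the extra XLA/TF logging
import jax
jax.jit(jax.random.PRNGKey, backend='cpu')(0)  # first op initializes backends and should print the allocator line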
That said, I'm not sure what the 140MB allocation is about. I'm seeing an extra ~105MB allocated too. I turned on a bunch of GPU memory logging with TF_CPP_VMODULE=bfc_allocator=1,stream_executor_pimpl=1, but the output indicates it's only allocating an initial 2MB: https://gist.github.com/skye/931c8371b858cab96bae251652bfdf11
Maybe the CUDA runtime allocates some space for itself?
Ideally I guess one wouldn't expect the GPU backend to be initialized when working only on the CPU backend.
Agreed. I'm gonna look into not initializing all backends when a specific backend is requested.
I tried the same code in a Python interpreter launched with TF_CPP_MIN_LOG_LEVEL=0 python, and I am getting a line saying:
XLA backend allocating 10416055910 bytes on device 0 for BFCAllocator.
while nvidia-smi tells me that only 140 MB is allocated.
That's very strange. Can you provide the full jax + nvidia-smi output?
That's very strange. Can you provide the full jax + nvidia-smi output?
I have put the jax + nvidia-smi output in this gist.
FWIW I've spotted the same thing recently. I'm still not sure what is doing the allocating, but I don't think this allocation is coming via the CUDA runtime or driver APIs (I've tried hooking those and printing when allocations happen).
For a minimal repro:
!/usr/lib/libcuda/nvidia-smi
import jax
assert jax.default_backend() == 'gpu'
!/usr/lib/libcuda/nvidia-smi
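Or, outside a notebook, roughly the same thing in plain Python (the nvidia-smi path is just what my system uses; adjust as needed):
import subprocess
subprocess.run(['/usr/lib/libcuda/nvidia-smi'])  # GPU memory before any JAX operation
import jax
assert jax.default_backend() == 'gpu'  # forces backend initialization
subprocess.run(['/usr/lib/libcuda/nvidia-smi'])  # the extra ~100-140 MB shows up here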
Maybe the CUDA runtime allocates some space for itself?
I suspect so...