jax
Memory leak when calling JAX script in Python
Is there a way, either via a JAX command or a CUDA command, to free GPU memory? I'm having a problem with a memory leak in a script like the following:
for _ in range(10):
    doJax(inputs)
which gives an out-of-memory error a few calls in, despite there being ostensibly no residual GPU memory usage. I'm aware of the GPU memory allocation note, but I'm not looking to change preallocation behavior. I merely want to replicate what would happen if I called doJax from a shell 10 times (which works without issue).
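To illustrate, here is a minimal sketch of the sort of explicit cleanup I have in mind, assuming doJax returns ordinary JAX device arrays (doJax and inputs are the placeholders from the loop above):
import gc

for _ in range(10):
    result = doJax(inputs)
    result.block_until_ready()  # wait for the computation to actually finish
    del result                  # drop the Python reference to the device buffer
    gc.collect()                # give the allocator a chance to reclaim GPU memory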
Error trace:
File "runClassificationSeries.py", line 371, in <module>
func_list[int(sys.argv[1])]()
File "runClassificationSeries.py", line 365, in trainAccel4
train_models.train_model(model, process_params, in_P, labs)
File "/home/Projects/scripts/../src/train_models.py", line 366, in train_model
loss, error, loss_state, err_state, init_state, del_loss = train(model, train_d, test_d, shape_type, norm_funcs, preProcessParams, in_P[i])
File "/home/Projects/scripts/../src/train_models.py", line 214, in train
train_ds, params['batch_size'], epoch, input_rng, params['l2_regularization'], learning_rate_fn)
File "/home/Projects/scripts/../src/train_models.py", line 129, in train_epoch
new_model_state, optimizer, metrics, lr = train_step(state.model_state, state.optimizer, rng, state.step, batch, l2_reg, loss_function, metric_function, learning_rate_fn)
File "/opt/conda/lib/python3.7/site-packages/jax/api.py", line 170, in f_jitted
name=flat_fun.__name__, donated_invars=donated_invars)
File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 1098, in call_bind
outs = primitive.impl(fun, *args, **params)
File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 540, in _xla_call_impl
return compiled_fun(*args)
File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 770, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
System info: Debian 10, CUDA 10.1.243, jax 0.1.73, Python 3.7.6, jaxlib 0.1.51. I've also tried other OSs and other versions of the above, and the problem persists.
Thanks for your help.
What does the doJax function contain?
It calls some Flax modules to train a neural network. I posted here because I assumed that memory deallocation is handled, or at least handle-able, through JAX.
Yes, memory allocation/deallocation should be handle-able through JAX, but I cannot reproduce the issue. For example, this works fine, despite the total size of the created arrays being larger than the available GPU memory:
import jax.numpy as jnp
from jax import jit, random

@jit
def doJax(key):
    key = random.PRNGKey(key)
    x = random.uniform(key, (5000, 5000), dtype=jnp.float32)
    y = random.uniform(key, (5000, 50000), dtype=jnp.float32)
    z = x @ y
    return z.mean()

for key in range(100):
    print(doJax(key))
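Each iteration's buffers are freed as soon as z and the other intermediates lose their last Python reference, which is why steady-state memory use stays bounded here even though the cumulative allocation far exceeds GPU memory.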
Can you provide a minimal example of a code snippet that leads to the memory leak issue you're seeing?
I'll see if I can create it in the future, but it is a fairly large, branching set of code so it will take some time to distill the problematic parts. I don't believe the leak to be a general feature of JAX code (the above example runs fine), but something more specific to my flax-based methods. My primary goal was to see if there were any JAX-related commands available to deallocate all GPU memory and "reset" itself.
It would be interesting to see whether the device memory profiling support shows where the memory usage is coming from:
https://jax.readthedocs.io/en/latest/device_memory_profiling.html
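A minimal sketch of what that looks like, following the documented jax.profiler API (the filename is arbitrary):
import jax
import jax.numpy as jnp
import jax.profiler

x = jnp.ones((1000, 1000))  # stands in for whatever is live on the GPU
jax.profiler.save_device_memory_profile("memory.prof")
# Then visualize the dump with pprof, e.g.:  pprof --web memory.prof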
Since we weren't able to come up with a reproducer, I'm going to close this issue as stale (after two years!). A lot has changed since then, including limits on cache sizes. Please open a new issue if memory leaks seem to arise!
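For anyone landing here later: one pattern that can produce exactly this kind of unbounded growth is passing a fresh function object as a static argument to a jitted function on every call, so each call compiles and caches a new executable. A hypothetical sketch, not necessarily what happened in this issue:
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def train_step(x, loss_fn):
    return loss_fn(x)

x = jnp.ones((10,))
for step in range(100):
    # Each new lambda is a distinct static value, so jit traces,
    # compiles, and caches a brand-new executable on every iteration.
    train_step(x, lambda v: (v ** 2).sum())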