
Memory leak when calling jax script in python

Open GitInTheRobot opened this issue 5 years ago • 5 comments

Is there a way, using either a JAX command or a CUDA command, to free GPU memory? I'm running into a memory leak in a script like the following:

```python
for _ in range(10):
    doJax(inputs)
```

which raises an out-of-memory error a few calls in, even though there is ostensibly no residual GPU memory usage. I'm aware of the GPU memory allocation note, but I'm not looking to change the preallocation behavior. I merely want to replicate what happens when I call doJax from a shell 10 times (which works without issue).

Error trace:

```
  File "runClassificationSeries.py", line 371, in <module>
    func_list[int(sys.argv[1])]()
  File "runClassificationSeries.py", line 365, in trainAccel4
    train_models.train_model(model, process_params, in_P, labs)
  File "/home/Projects/scripts/../src/train_models.py", line 366, in train_model
    loss, error, loss_state, err_state, init_state, del_loss = train(model, train_d, test_d, shape_type, norm_funcs, preProcessParams, in_P[i])
  File "/home/Projects/scripts/../src/train_models.py", line 214, in train
    train_ds, params['batch_size'], epoch, input_rng, params['l2_regularization'], learning_rate_fn)
  File "/home/Projects/scripts/../src/train_models.py", line 129, in train_epoch
    new_model_state, optimizer, metrics, lr = train_step(state.model_state, state.optimizer, rng, state.step, batch, l2_reg, loss_function, metric_function, learning_rate_fn)
  File "/opt/conda/lib/python3.7/site-packages/jax/api.py", line 170, in f_jitted
    name=flat_fun.__name__, donated_invars=donated_invars)
  File "/opt/conda/lib/python3.7/site-packages/jax/core.py", line 1098, in call_bind
    outs = primitive.impl(fun, *args, **params)
  File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 540, in _xla_call_impl
    return compiled_fun(*args)
  File "/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py", line 770, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
```

System info:

- Debian 10
- CUDA 10.1.243
- jax 0.1.73
- jaxlib 0.1.51
- Python 3.7.6

I've also tried other operating systems and other versions of the above, and the problem persists.

Thanks for your help

GitInTheRobot, Jul 28 '20 18:07
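The behaviour requested above, matching what happens when doJax is launched from a shell 10 times, can be obtained by running each call in its own Python interpreter. The following is a minimal sketch, assuming the workload can be expressed as a standalone script: `CHILD_SCRIPT` is a hypothetical stand-in for `doJax`, and `run_in_fresh_process` is an illustrative helper, not a JAX API. When a child process exits, its CUDA context is destroyed and the driver reclaims all of its device memory, preallocated pool included, so nothing can accumulate across calls.

```python
import os
import subprocess
import sys

# Hypothetical stand-in for the real doJax workload. It runs in a brand-new
# interpreter, so JAX and its GPU allocator are initialised from scratch.
CHILD_SCRIPT = """
import sys
import jax.numpy as jnp

n = int(sys.argv[1])
x = jnp.ones((n, n), dtype=jnp.float32)
print(float((x @ x).mean()))
"""


def run_in_fresh_process(n, extra_env=None):
    """Run CHILD_SCRIPT in a separate interpreter and return its printed result.

    All device memory the child used is released when it exits, which mimics
    launching the script from a shell.
    """
    env = dict(os.environ)
    if extra_env:
        env.update(extra_env)  # e.g. allocator settings for the child only
    result = subprocess.run(
        [sys.executable, "-c", CHILD_SCRIPT, str(n)],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child crashes (e.g. OOM)
    )
    # JAX writes warnings to stderr, so the last stdout line is the result.
    return float(result.stdout.strip().splitlines()[-1])
```

To mirror the loop in the question, call `run_in_fresh_process` inside `for _ in range(10):`. The `extra_env` hook can pass `XLA_PYTHON_CLIENT_ALLOCATOR=platform` to the child; the GPU memory allocation note describes that setting as allocating and freeing memory on demand, which is slow but useful when debugging OOM failures. The trade-off of this design is paying JAX import and compilation time again in every child.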

What does the doJax function contain?

jakevdp, Jul 28 '20 19:07

> What does the doJax function contain?

It calls some flax modules to train a neural network. I posted here because I assumed that memory deallocation is handled (or can be handled) through JAX.

GitInTheRobot, Jul 28 '20 20:07

Yes, memory allocation and deallocation should be handled through JAX, but I can't reproduce the issue. For example, the following works fine even though the total size of the arrays it creates exceeds the available GPU memory:

```python
import jax.numpy as jnp
from jax import jit, random

@jit
def doJax(key):
  key = random.PRNGKey(key)
  x = random.uniform(key, (5000, 5000), dtype=jnp.float32)
  y = random.uniform(key, (5000, 50000), dtype=jnp.float32)
  z = x @ y
  return z.mean()

for key in range(100):
  print(doJax(key))
```

Can you provide a minimal example of a code snippet that leads to the memory leak issue you're seeing?

jakevdp, Jul 28 '20 20:07
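In the loop above, each iteration's result is dropped as soon as it is printed, so its device buffer can be freed. A cheap way to check whether a real training loop behaves differently is to count live device arrays after each step with `jax.live_arrays()`, which is available in recent JAX releases (older versions may not have it). This is a sketch under stated assumptions: `step` and `live_array_counts` are hypothetical, and `keep_results=True` simulates one common source of apparent leaks, appending device-side metrics to a Python list.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for a training step that returns a scalar metric.
    return jnp.tanh(x @ x.T).mean()


def live_array_counts(n_steps, keep_results):
    """Return the number of live jax.Arrays after each step of a toy loop.

    keep_results=True keeps every step's device-side output referenced from
    a Python list, so those buffers can never be freed.
    """
    kept = []
    counts = []
    for i in range(n_steps):
        x = jnp.full((64, 64), float(i), dtype=jnp.float32)
        y = step(x)
        if keep_results:
            kept.append(y)  # the reference, not the computation, is the "leak"
        counts.append(len(jax.live_arrays()))
    return counts
```

A count that stays flat matches the behaviour of the example above, while a count that grows by one per step points at Python references. Storing metrics as host values, for example with `float(y)` or `jax.device_get(y)`, lets the device buffers be released.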

I'll see if I can create one, but it's a fairly large, branching codebase, so it will take some time to distill the problematic parts. I don't believe the leak is a general feature of JAX code (the example above runs fine), but rather something specific to my flax-based methods. My primary goal was to find out whether there are any JAX commands available to deallocate all GPU memory and "reset" JAX.

GitInTheRobot, Jul 29 '20 02:07
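For the explicit "deallocate everything and reset" part of the question, recent JAX versions provide a few building blocks: `jax.live_arrays()`, `jax.Array.delete()`, and `jax.clear_caches()`. The sketch below combines them; `free_all_live_arrays` is a hypothetical helper name, not a JAX function. It is a blunt instrument: every array it deletes becomes unusable even if other code still holds a reference to it.

```python
import gc

import jax
import jax.numpy as jnp


def free_all_live_arrays():
    """Delete every live jax.Array and clear JAX's in-memory caches.

    Returns the number of arrays that were live before deletion. With the
    default preallocating GPU allocator, the freed bytes return to JAX's own
    memory pool rather than to the CUDA driver, so nvidia-smi will still show
    the preallocated amount as in use.
    """
    arrays = jax.live_arrays()
    for arr in arrays:
        arr.delete()  # releases the device buffer; later use raises an error
    jax.clear_caches()  # drops cached traces and compiled executables
    gc.collect()  # collect any Python objects caught in reference cycles
    return len(arrays)
```

Because the memory stays inside JAX's pool, this cannot fully reproduce a fresh process; if the goal is to hand memory back to the driver, running each call in a separate process remains the more reliable option.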

It would be interesting to see whether the device memory profiling support shows where the memory usage is coming from:

https://jax.readthedocs.io/en/latest/device_memory_profiling.html

hawkinsp, Jul 29 '20 20:07
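Following the profiling suggestion, one way to locate a leak is to save a device memory profile after each training step and compare snapshots. This is a minimal sketch using `jax.profiler.save_device_memory_profile`, which writes a profile in pprof format; `train_step`, `profile_steps`, and the deliberate `history` leak are hypothetical stand-ins for a real training loop.

```python
import os

import jax
import jax.numpy as jnp
import jax.profiler


@jax.jit
def train_step(x):
    # Hypothetical stand-in for one training step.
    return jnp.tanh(x) * 0.5


def profile_steps(n_steps, out_dir):
    """Save one device memory profile (pprof format) after each step."""
    x = jnp.ones((256, 256), dtype=jnp.float32)
    history = []  # deliberately keeps every output alive to simulate a leak
    paths = []
    for i in range(n_steps):
        x = train_step(x)
        history.append(x)
        path = os.path.join(out_dir, f"memory{i}.prof")
        jax.profiler.save_device_memory_profile(path)
        paths.append(path)
    return paths
```

With the pprof tool installed, `pprof --web --diff_base memory0.prof memory2.prof` shows which allocations grew between the two snapshots. In a real run, growth attributed to a particular Python line points at whatever keeps those arrays referenced.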

Since we weren't able to come up with a reproducer, I'm going to close this issue as stale (after 2 years!). A lot has changed since then, including more cache limits. Please open new issues if memory leaks seem to arise!

mattjj, Aug 16 '22 18:08