numba.jit memory leak
The following loop produces a memory leak. If nb.jit is taken out of the loop, there is no leak.
```python
import numba as nb
import numpy as np

for i in range(20000000):
    @nb.jit(nopython=True, nogil=True)
    def testNumba(x):
        return np.sum(x)

    a = np.ones(100)
    testNumba(a)
```
I'm using Numba v0.20.0 by the way.
Do you have a real-world workload where this is affecting you? We try to release resources when we can, but I wouldn't be surprised if we don't release all memory (especially because LLVM itself may prevent us from that).
Thanks for replying! I have a dll that embeds Python into a system that is up for days. A memory leak, though small, could mean I'd have to reset the system every few days.
Well, if you compile a function in a loop, you will lose a lot of time compiling, so it's much better to move the @jit call out of the loop anyway. The only way this could hurt you is if you continuously generate new functions to be compiled, I think.
That's right. In my case there is a Python function being called, say, every few minutes by the enclosing system (which is not aware of the Python details). This function in turn uses @jit. So it's not as bad as a very rapid loop doing jits. I could solve this by keeping state between function calls as you suggest. I guess I wouldn't prioritize this issue very high, but it would be nice to know there are no leaks that require such workarounds :)
I've tried this again. Here is a script that shows a bit more useful info during running (and also calls gc.collect() from time to time): https://gist.github.com/pitrou/bada94f31e1c8891585c
In 1000 iterations the memory consumption indicated under Linux (the RSS column in ps or top) grows from 65 to 117 MB, i.e. about 50 kB per compiled function. There doesn't seem to be any Python-facing memory leak, since sys.getallocatedblocks() remains approximately stable.
The "leak" comes either from memory allocated and never released by LLVM, or some mishandling of LLVM's lifetime and ownership rules on our part...
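The linked gist does roughly the following; here is a stdlib-only sketch (the `rss_kib` and `measure` helper names are made up, and the Unix-only `resource` module reports peak RSS in KiB on Linux):

```python
import gc
import resource

def rss_kib():
    # Peak RSS of this process; KiB on Linux (bytes on macOS).
    # Good enough for spotting monotonic growth across iterations.
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

def measure(work, iterations=1000, report_every=100):
    """Run work() repeatedly, sampling peak RSS to spot steady growth."""
    samples = []
    for i in range(iterations):
        work()
        if i % report_every == 0:
            gc.collect()  # rule out garbage the collector could still free
            samples.append(rss_kib())
    return samples

# Trivial workload here; the real test would define and call a freshly
# @jit-compiled function on every iteration.
samples = measure(lambda: sum(range(1000)))
```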
I found a similar memory leak issue when putting a jitted function inside another function and calling that from a loop. There is no leak when the jitted function is pulled out of the outer function. See file for leak and workaround: https://gist.github.com/ajkur/b5e1f7908d1d2a1f2c43b82690cc82e9
I've stumbled into this when jitting closures using numba 0.40
Thank you @ajkur. I had a scenario like this:
```python
@njit
def fun1():
    for a in fun2():
        # do something with a
        ...
    return

@njit
def fun2():
    for i in range(100):
        yield i
```
After removing the second function by placing its contents directly into the first, the memory usage over time went from steadily growing to flat. (The before/after plots are not reproduced here.)
I'm also seeing a memory leak with numba==0.50.1. I'm using a thread pool to parallelize a computation that releases the GIL; an example is included below.
It seems that prange does not have this issue: its memory usage is constant, whereas the threading solution's memory usage grows in proportion to the number of work items. I've also tried multiprocessing.pool.ThreadPool, but the issue persists. Note that I ran each test_function(...) in a separate interpreter to produce these plots.
(Memory-usage plots for use_prange and use_threading omitted.)
```python
import matplotlib.pyplot as plt
import numpy as np
from numba import njit, prange, set_num_threads
from joblib import Parallel, delayed
from memory_profiler import memory_usage


@njit(nogil=True)
def worker(x, nsamp, seed):
    np.random.seed(seed)
    m, n = x.shape
    idx = np.random.choice(m, nsamp)
    xsamp = x[idx]
    dists = np.zeros(nsamp)
    for i in range(nsamp):
        for j in range(i + 1, nsamp):
            d = 0.0
            for k in range(n):
                tmp = xsamp[i, k] - xsamp[j, k]
                d += tmp * tmp
            dist = np.sqrt(d)
            dists[i] += dist
            dists[j] += dist
    entry = xsamp[dists.argmin()]
    total = 0.0
    for i in range(m):
        d = 0.0
        for k in range(n):
            tmp = x[i, k] - entry[k]
            d += tmp * tmp
        dist = np.sqrt(d)
        total += dist
    return entry, total


def use_threading(n=100000, nsamp=1000, seed=0, n_jobs=10, size=1000000):
    x = np.random.RandomState(seed).rand(size, 2)

    def wrapper(i):
        return worker(x, nsamp, i)

    with Parallel(n_jobs, backend='threading') as parallel:
        return parallel(delayed(wrapper)(i) for i in range(n))


@njit(nogil=True, parallel=True)
def use_prange(n=100000, nsamp=1000, seed=0, n_jobs=10, size=1000000):
    np.random.seed(seed)
    set_num_threads(n_jobs)
    x = np.random.rand(size, 2)
    for i in prange(n):
        worker(x, nsamp, i)


def test_function(func):
    mem = memory_usage(func, include_children=True, multiprocess=True)
    time = np.arange(len(mem)) / 10
    fig, ax = plt.subplots()
    ax.plot(time, mem)
    ax.set(xlabel='time (s)', ylabel='Memory usage (MiB)', title=func.__name__)
    plt.show()


if __name__ == '__main__':
    test_function(use_prange)
    test_function(use_threading)
```
@trianta2 thanks for reporting this, I have tried this locally and can confirm I am able to reproduce your findings.
Here is a simple example to reproduce:

```python
import gc

import numpy as np
import numba


@numba.jit(cache=False)
def main(x):
    x ** 2


gc.enable()
gc.set_debug(gc.DEBUG_SAVEALL)
x = np.random.random((10,))
main(x)
gc.collect()
assert not len(gc.garbage), f"found {len(gc.garbage)} objects."
```

This raises an `AssertionError`.
@Enolerobotti thank you for submitting this. I removed all Numba references from your example and I also receive an AssertionError, so I am unsure what the example you posted is testing. Can you shed some light on this?
```
💣 zsh» cat issue_1361.py
import gc
gc.enable()
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
assert not len(gc.garbage), f"found {len(gc.garbage)} objects."
💣 zsh» python issue_1361.py
Traceback (most recent call last):
  File "/Users/esc/git/numba/issue_1361.py", line 5, in <module>
    assert not len(gc.garbage), f"found {len(gc.garbage)} objects."
AssertionError: found 51 objects.
```
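The likely reason: with gc.DEBUG_SAVEALL set, the collector moves every collectable object into gc.garbage instead of freeing it, so any interpreter that has created reference cycles, with or without Numba involved, will fail that assert. A minimal demonstration:

```python
import gc

gc.set_debug(gc.DEBUG_SAVEALL)

# A trivial reference cycle, no Numba involved.
cycle = []
cycle.append(cycle)
del cycle

gc.collect()
# DEBUG_SAVEALL keeps every collected object in gc.garbage, so a nonzero
# count here reflects the debug flag, not a memory leak.
found = len(gc.garbage)

# Restore normal collector behavior.
gc.set_debug(0)
gc.garbage.clear()
```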