Anndata not properly garbage collected
Description
When an anndata object is deleted in the current scope, the underlying memory is not (reliably) freed as it is for numpy and other libraries. The memory stays allocated until the process exits. This leads to huge memory consumption if several anndata files are read sequentially in the same process.
Versions: Python 3.7.6 and Python 3.8.2, anndata==0.7.1, scanpy==1.4.6
Code to Reproduce:
import scanpy as sc
import numpy as np
import os
import tracemalloc

def display_top(snapshot, key_type='lineno'):
    snapshot = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    top_stats = snapshot.statistics(key_type)
    total = sum(stat.size for stat in top_stats)
    return total

def trace_function(func, arg, n):
    total = np.zeros(n)
    data = func(arg).copy()
    tracemalloc.start()
    for i in range(n):
        data = func(arg).copy()
        snapshot = tracemalloc.take_snapshot()
        total[i] = display_top(snapshot)
    tracemalloc.stop()
    return total

_NP_FILENAME = 'numpydata.npy'
_ANNDATA_FILENAME = 'testfile.h5ad'
_RUNS = 5

data = sc.read(_ANNDATA_FILENAME)
np.save(_NP_FILENAME, np.random.rand(*data.shape))

total_np = trace_function(np.load, _NP_FILENAME, _RUNS)
total = trace_function(sc.read, _ANNDATA_FILENAME, _RUNS)
display(total_np)
array([1.74600576e+09, 1.74602235e+09, 1.74602685e+09, 1.74602814e+09,
1.74602672e+09])
display(total)
array([1.74998379e+09, 2.62655985e+09, 3.50312986e+09, 2.62657418e+09,
3.50315218e+09])
import matplotlib.pyplot as plt
plt.plot(np.arange(_RUNS), total_np, label='numpy')
plt.plot(np.arange(_RUNS), total, label='anndata')
plt.legend()
plt.show()
I think I've noticed this before but haven't been able to figure out what's going on. If it were a simple memory leak where the memory use just kept increasing, that would make sense to me and would obviously be a leak. But the memory sometimes gets collected.
Do you have any idea what could be going on?
Thank you for your quick reply. Yes, you're right, it's not a normal memory leak. I did a small test today with pympler and it looks like the complete object (and not only parts of it) is kept in memory for some time. I have no idea yet why this is the case, but I'll let you know when I have investigated it in more detail.
---- SUMMARY ------------------------------------------------------------------
Before juggling with tracked obj...   active      0 B      average   pct
  anndata._core.anndata.AnnData            0      0 B          0 B     0%
First init                            active      0 B      average   pct
  anndata._core.anndata.AnnData            1   1.65 GB      1.65 GB   73%
0. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            2   3.30 GB      1.65 GB  107%
1. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            3   4.95 GB      1.65 GB  127%
2. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            4   6.61 GB      1.65 GB  139%
3. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            5   8.26 GB      1.65 GB  148%
4. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            6   9.91 GB      1.65 GB  155%
5. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            7  11.56 GB      1.65 GB  160%
6. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            2   3.30 GB      1.65 GB  105%
7. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            3   4.95 GB      1.65 GB  125%
8. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            4   6.61 GB      1.65 GB  138%
9. iteration                          active      0 B      average   pct
  anndata._core.anndata.AnnData            5   8.26 GB      1.65 GB  147%
-------------------------------------------------------------------------------
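For reference, a summary like the one above can be produced with pympler's ClassTracker roughly as follows (a sketch, not the exact script used here; sc.read and _ANNDATA_FILENAME are taken from the earlier snippet):

from pympler.classtracker import ClassTracker
import anndata
import scanpy as sc

tracker = ClassTracker()
tracker.track_class(anndata.AnnData)
tracker.create_snapshot('Before juggling with tracked objects')

data = sc.read(_ANNDATA_FILENAME)
tracker.create_snapshot('First init')

for i in range(10):
    data = sc.read(_ANNDATA_FILENAME)
    tracker.create_snapshot(f'{i}. iteration')

tracker.stats.print_summary()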
One point I would like to know: can your machine run out of memory due to objects which could be collected? Is it possible that collection of these objects is triggered based on need?
Yes, it can. Not in the single-process example I showed above, but when I do it multi-processed my machine runs out of memory. See the code below for an example. This fills 32 GB of RAM quite fast with an anndata file that usually takes 1.65 GB in RAM.
import multiprocessing as mp
import anndata

_RUNS = 3

def trace_function():
    data = anndata.read_h5ad(_ANNDATA_FILENAME)
    for i in range(_RUNS):
        data = anndata.read_h5ad(_ANNDATA_FILENAME)
    return 0

with mp.Pool(3) as pool:
    results = pool.starmap_async(trace_function, [() for _ in range(100)])
    print(results.get())
There is a workaround: setting mp.Pool(3, maxtasksperchild=1) means that when a (sub)process exits, the memory is garbage collected as expected; but if the processes are reused, the anndata objects still accumulate. The workaround is sketched below.
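For completeness, the workaround spelled out (same pool as above, just with maxtasksperchild=1 so each worker process is replaced after a single task and takes its accumulated AnnData objects with it):

with mp.Pool(3, maxtasksperchild=1) as pool:
    results = pool.starmap_async(trace_function, [() for _ in range(100)])
    print(results.get())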
The issue seems to be circular references, which cannot be resolved by the standard reference-counting garbage collector; they can only be resolved by the full cyclic garbage collector (gc.collect()). This somewhat explains the weird behaviour: in the single-process example the full garbage collector is triggered once the allocated memory reaches certain limits, whereas in the multi-process example each individual process never reaches that limit, yet in total all RAM is allocated.
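A minimal, anndata-independent sketch of why such objects survive del: reference counting alone never frees a cycle, only the cyclic collector does.

import gc
import weakref

class Node:
    pass

a = Node()
b = Node()
a.other = b                    # a -> b
b.other = a                    # b -> a: a reference cycle

alive = weakref.finalize(a, lambda: print("cycle reclaimed"))

del a, b                       # refcounts never drop to zero, nothing is freed
print(alive.alive)             # True: the pair is still in memory
gc.collect()                   # the cycle collector finds the unreachable pair
print(alive.alive)             # False: only now is the memory released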
The culprits seem to be _parent in anndata._core.aligned_mapping.Layers, adata in anndata._core.file_backing.AnnDataFileManager, and probably some more, which you can show with:

import gc
gc.get_referrers(data)

In total I get 6 circular references and 1 actual reference for the anndata object. The best fix, in my opinion, would be to store only weak references here (or drop them completely if possible), but weak references seem to prevent pickling of the object.
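To illustrate the idea, here is a hedged sketch (not the actual anndata code) of a child object holding its parent through a weakref and dropping it during pickling:

import weakref

class AlignedMapping:                        # hypothetical stand-in for e.g. Layers
    def __init__(self, parent):
        self._parent_ref = weakref.ref(parent)   # weak back-reference, no cycle

    @property
    def parent(self):
        parent = self._parent_ref()
        if parent is None:
            raise RuntimeError("parent object has already been garbage collected")
        return parent

    def __getstate__(self):
        # weakref objects cannot be pickled; drop the back-reference here and
        # have the parent re-attach itself after unpickling
        state = self.__dict__.copy()
        state.pop("_parent_ref", None)
        return state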
Glad to see you open a PR for this!
Just a general point about this as an issue though. Unless this memory usage is being triggered by non-pathological code, I'm not sure it justifies huge changes in APIs or intended behavior to fix it.
In practice, I think the gc runs frequently enough that these objects get collected. All the example code we have here is pretty weird, just allocating large amounts of memory and doing nothing with it. From what I can see, making a plot or running a pca seems to reliably deallocate the memory.
I'd definitely like to see this fixed, but I would like to be conservative about how it's fixed.
Thank you again for your quick responses and willingness to help.
Just a general point about this as an issue though. Unless this memory usage is being triggered by non-pathological code, I'm not sure it justifies huge changes in APIs or intended behavior to fix it.
I can understand that you don't want changes to APIs or similar, and I'm happy to take your suggestions so that we can avoid this.
In practice, I think the gc runs frequently enough that these objects get collected. All the example code we have here is pretty weird, just allocating large amounts of memory and doing nothing with it.
Sure, these code snippets are very pathological, but we see this behaviour also in non-pathological situations/pipelines where we process or copy anndata objects a lot. That code is obviously too large to show here, though.
From what I can see, making a plot or running a pca seems to reliably deallocate the memory.
Additionally, I cannot agree that running PCA reliably deallocates the memory. If you modify the above snippet as follows, it still gradually fills memory, and I don't think running 100 PCAs in parallel is that pathological.
import os
import anndata
import scanpy as sc
import multiprocessing as mp

def do_pca(filename):
    data = anndata.read_h5ad(filename)
    sc.tl.pca(data)
    data.write('tmpfile.h5ad')
    os.remove('tmpfile.h5ad')
    return 0

with mp.Pool(3) as pool:
    results = pool.starmap_async(do_pca, [(_ANNDATA_FILENAME,) for _ in range(100)])
    print(results.get())
Those examples sound more reasonable. I just gave this case a try with your PR and the current master, but I wasn't able to see any consistent improvement in memory usage on your PR branch. I'm not completely sure what to make of all this though, since I've had generally negative experiences with memory usage and multiprocessing. Here are the results I recorded:
The script I used
Data was generated with:
from scipy import sparse
import scanpy as sc

(
    sc.AnnData(sparse.random(50000, 10000, format="csr"))
    .write_h5ad("test_sparse.h5ad", compression="lzf")
)
I was also using scanpy master for the efficient sparse PCA implementation.
import os
os.environ["OMP_NUM_THREADS"] = "4"

import anndata
import scanpy as sc
import multiprocessing as mp
import gc

_ANNDATA_FILENAME = "./test_sparse.h5ad"

def do_pca(filename):
    data = anndata.read_h5ad(filename)
    data.layers["dense"] = data.X.toarray()  # To increase memory usage
    sc.tl.pca(data)
    return 0

with mp.Pool(2) as pool:
    results = pool.starmap_async(do_pca, [(_ANNDATA_FILENAME,) for _ in range(30)])
    print(results.get())
Of course, YMMV. Have you seen memory usage improvements for your workflows using your PR?
Yes, I got improved memory usage. I have now tested with the master branch of scanpy and got the same results as you. I guess it's not completely fixed, and PCA may store circular references somewhere in the anndata object? Strangely, I didn't see this behaviour with stable scanpy.
I went back to a toy example and replaced sc.tl.pca with time.sleep(1) and got the following results.
I haven't checked our pipeline with the PR yet; however, maxtasksperchild in multiprocessing.Pool solved the issue for us in several cases.
I went back to a toy example and replaced sc.tl.pca with time.sleep(1) and got the following results.
I figured this one out, and am now able to see improved memory usage. Basically, accessing uns creates circular references, where one of those objects holds a reference to the anndata object. It's a very good use case for weakref, since one of the objects should only ever be referred to by the other. I'll make a quick PR for this, since it's a two-line change, and you can merge that into your branch.
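A rough way to observe the effect described above (assuming a small in-memory AnnData; whether the count is non-zero may vary between versions):

import gc
import numpy as np
import anndata as ad

adata = ad.AnnData(np.ones((100, 100)))
_ = adata.uns        # touching .uns is what reportedly sets up the cycle
del adata

# gc.collect() returns the number of unreachable objects it found; a non-zero
# count here means the objects could only be reclaimed by the cycle collector
print(gc.collect())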
however, maxtasksperchild in multiprocessing.Pool solved the issue for us in several cases.
Good to hear. As an aside, I've generally had much better experiences with resource handling through dask than multiprocessing. This is especially true if tasks rely on each other, since dask is good about optimizing for data locality.
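For example, the PCA runs from above could be submitted to a local dask cluster along these lines (a sketch reusing do_pca and _ANNDATA_FILENAME from the earlier snippets; the cluster sizing is illustrative):

from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=3, threads_per_worker=1)
client = Client(cluster)

# pure=False because do_pca has side effects and every call should actually run
futures = [client.submit(do_pca, _ANNDATA_FILENAME, pure=False) for _ in range(100)]
print(client.gather(futures))

client.close()
cluster.close()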
Thanks for the fix and for the hint about dask!
I just ran into a similar issue with the current anndata version 0.7.5.
Unfortunately I don't know enough about Python's garbage collection to fix it myself, but maybe a really trivial example might help you fix it:
import numpy as np
import anndata as ad

def do_stuff(adata):  # %mprun magic works only if this is defined somewhere else in a file
    copy0 = adata.copy()
    del copy0
    copy1 = adata.copy()
    del copy1
    copy2 = adata.copy()
    del copy2

adata = ad.AnnData(np.ones((10000, 10000)))

# copy numpy array
%mprun -f do_stuff do_stuff(adata.X)
# copy anndata
%mprun -f do_stuff do_stuff(adata)
Output for the numpy part:
Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
   957    581.3 MiB    581.3 MiB           1   def do_stuff(adata):
   958    962.7 MiB    381.5 MiB           1       copy0 = adata.copy()
   959    581.3 MiB   -381.5 MiB           1       del copy0
   960    962.6 MiB    381.3 MiB           1       copy1 = adata.copy()
   961    581.3 MiB   -381.3 MiB           1       del copy1
   962    962.6 MiB    381.3 MiB           1       copy2 = adata.copy()
   963    581.3 MiB   -381.3 MiB           1       del copy2
Output for the anndata part:
Line #    Mem usage    Increment  Occurences   Line Contents
============================================================
   957    581.3 MiB    581.3 MiB           1   def do_stuff(adata):
   958    963.1 MiB    381.8 MiB           1       copy0 = adata.copy()
   959    963.1 MiB      0.0 MiB           1       del copy0
   960   1345.1 MiB    382.1 MiB           1       copy1 = adata.copy()
   961   1345.1 MiB      0.0 MiB           1       del copy1
   962   1727.0 MiB    381.8 MiB           1       copy2 = adata.copy()
   963   1727.0 MiB      0.0 MiB           1       del copy2
With some luck, partial garbage collection takes place randomly in some runs, but over time a net leak remains.
Since it was mentioned that the issue would not occur in real life, a few words about my use case: I have an adata of a couple of GB and run a method for O(10) different parameter sets. Each call of the method must not change the original adata and therefore works on a local copy. All (many) of these local copies accumulate in memory, eventually crashing the program.
The gc.collect() call (thanks @fhausmann) fixes my real-life scenario, as does replacing .X in the copies with an empty sparse matrix, but having to do this every time one uses a temporary anndata is certainly not ideal...
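For what it's worth, a minimal sketch of that workaround (run_method and parameter_sets are placeholders for the actual pipeline, not code from this thread):

import gc

def run_method(adata, params):
    tmp = adata.copy()        # work on a local copy so the original stays untouched
    # ... the actual analysis on tmp would go here ...
    del tmp
    gc.collect()              # break the AnnData reference cycles immediately

for params in parameter_sets:
    run_method(adata, params)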
For others who end up on this thread: I had a similar memory leak in version 0.8.0 and handled it by converting my data to a sparse pandas data frame.
I am having the same issue in 0.9.1 when I loop over a few anndata objects. Even when I explicitly del the object every loop iteration, the problem persists. Per @fhausmann, using gc.collect() solves the problem.
I am having the same issue in 0.9.1 when I loop over a few anndata objects. Even when I explicitly del the object every loop iteration, the problem persists. Per @fhausmann, using gc.collect() solves the problem.
This works for me. I tried both adata = None followed by gc.collect() and del adata followed by gc.collect(), and only the latter works.
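The pattern described above, spelled out (h5ad_files and process are placeholders):

import gc
import anndata as ad

for path in h5ad_files:
    adata = ad.read_h5ad(path)
    process(adata)       # placeholder for the per-file work
    del adata            # drop the last strong reference ...
    gc.collect()         # ... then force the cycle collector to reclaim it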