cellpose
Memory usage increases during subsequent evaluations of cellpose model
Hi there, and thanks for your support.
While working on a different project with @mfranzon and @jluethi, we noticed an unexpected increase in RAM usage during subsequent runs of cellpose segmentation with the nuclei model. I'll report here an example that is as self-contained as possible, but other information is scattered across our original issues.
The question is whether this behavior is expected, or whether we could try to mitigate it. We are also wondering whether it comes from cellpose or from torch.
Context
Our goal is to segment 3D images with the cellpose pre-trained nuclei model. We need to segment a certain number of arrays (say 20 of them), and each array may have a shape like (30, 2160, 2560) and type uint16. The different arrays are processed sequentially (one cellpose call per array) on a node with 64G of memory and access to a GPU. GPU memory stays under control throughout the entire run (around 4 GiB out of 16 GiB is used); this issue concerns standard RAM usage, which we monitor via mprof.
Code and results
As a minimal working example, we load a single array of shape (30, 2160, 2560) and compute the corresponding labels several times in a row. If needed, we can find the best way to share the image folder, or use other data that are already easily available for testing.
The code looks like this:
```python
import sys
import time

import numpy as np
from skimage.io import imread
from cellpose import core, models


def run_cellpose(img, model):
    """Segment a 3D stack with an already-loaded cellpose model and return the labels."""
    t_start = time.perf_counter()
    print(f"START | shape: {img.shape}")
    sys.stdout.flush()
    mask, flows, styles, diams = model.eval(
        img,
        do_3D=True,
        channels=[0, 0],  # grayscale image, no second channel
        net_avg=False,
        augment=False,
        diameter=80.0,
        anisotropy=6.0,  # Z spacing relative to XY spacing
        cellprob_threshold=0.0,
    )
    t_end = time.perf_counter()
    print(f"END | num_labels={np.max(mask)}, elapsed_time={t_end-t_start:.3f}")
    sys.stdout.flush()
    return mask


# Read a 3D stack of images (42 Z planes are available; only the first 30 are used)
num_z = 30
stack = np.empty((num_z, 2160, 2560), dtype=np.uint16)
for z in range(num_z):
    stack[z, :, :] = imread(f"images_v1/20200812-CardiomyocyteDifferentiation14-Cycle1_B05_T0001F002L01A01Z{z+1:02d}C01.png")

# Initialize cellpose once and reuse the same model for every run
use_gpu = core.use_gpu()
model = models.Cellpose(gpu=use_gpu, model_type="nuclei")
print(f"End of initialization: num_z={num_z}, use_gpu={use_gpu}")

# Segment the same stack repeatedly to observe how RAM usage evolves
nruns = 10
for run in range(nruns):
    print(run)
    run_cellpose(stack, model)
```
This code runs to completion, and each segmentation takes approximately 320 seconds (finding around 3k labels). The memory trace during the first few iterations of the loop is shown below: each run uses more memory than the previous one, until usage saturates after a few iterations. For instance, the plateau regions in the memory trace have values (in GiB) of 12, 13.8, 14.1, 14.1, ... The memory-usage peaks at the end of each cellpose call also shift up by a similar amount, accumulating about 2 GiB during the first 2-3 iterations. The simplest explanation would be that cellpose or torch is caching something, but we could not identify what is being cached. Is this actually happening? If so, is there a way to deactivate this caching mechanism?
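To log the per-iteration numbers directly from the script, without mprof, here is a minimal sketch. It is an assumption and not what was used to produce the trace above. It records the current and peak resident memory (RSS) of the Python process. It reads /proc/self/statm, so it is Linux-only, and the helper names are hypothetical.

```python
import os
import resource
import sys


def current_rss_gib():
    """Current resident set size (RSS) of this process in GiB (Linux only).

    The second field of /proc/self/statm is the number of resident pages.
    """
    with open("/proc/self/statm") as f:
        resident_pages = int(f.read().split()[1])
    return resident_pages * os.sysconf("SC_PAGE_SIZE") / 1024**3


def peak_rss_gib():
    """Highest RSS reached by this process so far, in GiB.

    ru_maxrss is reported in KiB on Linux but in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    divisor = 1024**3 if sys.platform == "darwin" else 1024**2
    return peak / divisor


def log_memory(tag):
    """Print and return (current, peak) RSS in GiB, labelled with `tag`."""
    current, peak = current_rss_gib(), peak_rss_gib()
    print(f"{tag} | rss={current:.2f} GiB, peak_rss={peak:.2f} GiB")
    sys.stdout.flush()
    return current, peak
```

Calling `log_memory(f"after run {run}")` right after `run_cellpose(stack, model)` in the loop would print one line per run. A current value that keeps growing between runs would correspond to the rising plateaus, and a growing peak value would correspond to the rising end-of-call peaks.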
Expected behavior and why it matters
We would expect subsequent runs on the exact same input to require a very similar amount of memory, unless some caching is in place. This issue matters to us because, even if the memory accumulation seems mild (only 2 GiB more than expected), in more complex or heavier use cases (including additional parallelism) it may lead to memory errors (as we found in https://github.com/fractal-analytics-platform/fractal/issues/109#issuecomment-1198916009). For this reason we would really like to keep it under control, possibly by deactivating caching options (if any).
Environment
The python code above is submitted to a SLURM queue, and it runs on a node with a GPU available.
Relevant details on the python environment:
sys.version='3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]'
numpy.__version__='1.23.1'
torch.__version__='1.12.0+cu102'
I have no idea. Have you tried any garbage collection? You could call cellpose as a separate process, and then it will clean up (all those options are available on the CLI), but then you have to read the saved masks back in.
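The separate-process idea does not necessarily require the CLI. Here is a minimal sketch, using only the standard library, that keeps the task a Python function. Each call runs in a fresh child process, and the operating system reclaims all of that process's memory when it exits. Nothing in this thread confirms that this resolves the growth. The helper names `run_in_subprocess` and `segment_stack` are hypothetical. The "fork" start method is an assumption (Linux only). A forked child cannot reuse a CUDA context initialized in the parent, so the model is created inside the child.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def run_in_subprocess(func, *args, **kwargs):
    """Run func(*args, **kwargs) in a fresh child process and return its result.

    The child exits once the call finishes, so all of its memory goes back to
    the OS. Arguments and the return value are pickled between processes.
    """
    # Assumption: "fork" start method (Linux). The parent must not have
    # initialized CUDA before forking.
    ctx = mp.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        return pool.submit(func, *args, **kwargs).result()


def segment_stack(stack):
    """Hypothetical task: load the nuclei model and segment `stack` in the child."""
    from cellpose import core, models  # imported only inside the child process

    model = models.Cellpose(gpu=core.use_gpu(), model_type="nuclei")
    mask, flows, styles, diams = model.eval(
        stack,
        do_3D=True,
        channels=[0, 0],
        net_avg=False,
        augment=False,
        diameter=80.0,
        anisotropy=6.0,
        cellprob_threshold=0.0,
    )
    return mask
```

Usage would be `mask = run_in_subprocess(segment_stack, stack)`. The trade-off is that the model is reloaded on every call, and the input stack and the output labels are copied between processes.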
Thanks for your comment.
I can confirm that adding gc.collect() here and there (both within the run_cellpose function and, especially, right after each call to this function within the loop) does not lead to any relevant change in the memory trace.
At the moment we cannot take the CLI route, since this labeling task is part of a more complex platform for processing bio-images (https://github.com/fractal-analytics-platform/fractal), where tasks need to be Python functions.
For now we'll just keep this issue in mind, and apply mitigation strategies (e.g. working at a lower resolution) if/when needed.
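The "lower resolution" mitigation can be sketched in a few lines of NumPy, under stated assumptions. XY is downsampled by simple striding (no smoothing). The pixel-based `diameter` and the `anisotropy` (Z spacing relative to XY spacing) are both divided by the same factor. The labels are brought back to full size by nearest-neighbour repetition. The helper names are hypothetical, and this is not necessarily how fractal implements it.

```python
import numpy as np


def downsample_xy(stack, factor):
    """Keep every `factor`-th pixel along Y and X; Z is left untouched."""
    return stack[:, ::factor, ::factor]


def scaled_params(diameter, anisotropy, factor):
    """Rescale cellpose parameters for a stack downsampled in XY by `factor`.

    Objects become `factor` times smaller in pixels, and the XY spacing grows
    by `factor`, so the Z/XY spacing ratio shrinks by the same amount.
    """
    return diameter / factor, anisotropy / factor


def upsample_labels(labels, factor, target_shape):
    """Nearest-neighbour upsampling of a label array back to full resolution."""
    up = np.repeat(np.repeat(labels, factor, axis=1), factor, axis=2)
    # Crop in case the original Y/X size was not a multiple of `factor`.
    return up[:, : target_shape[1], : target_shape[2]]
```

With `factor=2`, a (30, 2160, 2560) stack becomes (30, 1080, 1280), and `diameter=80.0`, `anisotropy=6.0` become 40.0 and 3.0. The price is coarser label boundaries after upsampling.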