Slow performance when stitching thousands of masks
Hi cellpose team, we are trying to use cellpose with a custom model to segment high-density, nuclear-labeled 3D data from zebrafish. We use the stitching method to resolve the data in Z because our anisotropy varies, and it gives better detection than a full 3D model. However, on a full dataset cellpose takes a very long time to complete. We have upwards of 20k ROIs in total, but only around 6000 ROIs in each plane.
By running model.eval in isolation, we narrowed the slowdown to the late stage of stitching (i.e., the model part returns fine, and the second progress bar, "stitching x planes using stitch_threshold=y", also completes). Do you have any idea which part of the process might be responsible for the slow performance?
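One way to narrow this down further is to profile the call with Python's built-in cProfile and sort by cumulative time, so the slowest call chains show up first. A minimal sketch; `profile_call` is a hypothetical helper, and you would pass it whatever function and arguments you already use (e.g. your model.eval call):

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, n_lines=15, **kwargs):
    """Run fn(*args, **kwargs) under cProfile; return (result, report text)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        profiler.disable()
    buf = io.StringIO()
    # Sort by cumulative time so the slowest call chains appear first
    pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(n_lines)
    return result, buf.getvalue()
```

Printing the returned report should show whether the time is spent in the stitching itself or in a later post-processing call.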
We are using cellpose version 3.1.0
I just created a PR (https://github.com/MouseLand/cellpose/pull/1114) to address one of these slowdowns with large datasets. One of the steps in 3D with the stitching/IoU method builds an overlap table, in which the masks from z-slice x+1 and z-slice x are checked for overlapping labels. This is currently done pixel by pixel in a for-loop, which can take a long time. Some alternative operations remove the need for the loop, such as the one proposed in the PR. For me it made a huge difference, so let's see!
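To illustrate the idea (these are hypothetical helpers, not the code in cellpose or in the PR), here is a pixel-by-pixel overlap table next to a vectorized version using np.add.at, which accumulates correctly when the same (label_x, label_y) pair appears many times. The last function derives an IoU matrix of the kind the stitching/IoU method compares against stitch_threshold:

```python
import numpy as np


def label_overlap_loop(x, y):
    """Pixel-by-pixel overlap counts between two label images (slow in pure Python)."""
    x, y = x.ravel(), y.ravel()
    overlap = np.zeros((x.max() + 1, y.max() + 1), dtype=np.uint32)
    for i in range(len(x)):
        overlap[x[i], y[i]] += 1
    return overlap


def label_overlap_vectorized(x, y):
    """Same table without a Python loop: np.add.at accumulates repeated index pairs."""
    x, y = x.ravel(), y.ravel()
    overlap = np.zeros((x.max() + 1, y.max() + 1), dtype=np.uint32)
    np.add.at(overlap, (x, y), 1)
    return overlap


def intersection_over_union(overlap):
    """IoU matrix from an overlap table; row/column 0 is the background label."""
    n_pixels_x = overlap.sum(axis=1, keepdims=True)  # size of each label in x
    n_pixels_y = overlap.sum(axis=0, keepdims=True)  # size of each label in y
    union = n_pixels_x + n_pixels_y - overlap
    # Leave IoU at 0 where the union is empty instead of dividing by zero
    return np.divide(overlap, union, out=np.zeros(overlap.shape), where=union > 0)
```

Both overlap functions return the same table; only the vectorized one avoids a Python-level iteration over every pixel.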
Ouch, I can see how that could happen. I'll take a look to see if your PR also helps us, and I'll report back if I find other slowdowns. Cheers
@Tomvl117 Weird, implementing your change seems to lead to a crash on our end: the function containing the change is actually compiled with numba.jit(nopython=True), which doesn't seem to accept np.add.at.
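If np.add.at is the problem inside the jitted function, one possible workaround (a sketch, not tested against cellpose's numba setup) is to build the table outside the jitted code with np.bincount, encoding each (x, y) label pair as a single integer. `label_overlap_bincount` is a hypothetical name. Note that the dense table has n_x * n_y entries, so with ~6000 labels per plane it holds about 36 million integers (about 288 MB as int64):

```python
import numpy as np


def label_overlap_bincount(x, y):
    """Overlap table via np.bincount on flattened (x, y) label pairs, avoiding np.add.at."""
    x = x.ravel().astype(np.int64)
    y = y.ravel().astype(np.int64)
    n_x, n_y = int(x.max()) + 1, int(y.max()) + 1
    # Encode each (x, y) pair as one index into a flattened n_x * n_y table
    flat = x * n_y + y
    counts = np.bincount(flat, minlength=n_x * n_y)
    return counts.reshape(n_x, n_y)
```

Since this is already vectorized, it can be called from plain Python rather than from inside the numba-compiled function.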
Also, it seems that our slowdown happens after the utils.stitch3D call at line 636 in models.py (which is what leads to _label_overlap): in our case the stitching loop terminates and the call returns, so I'm guessing that the next line, utils.fill_holes_and_remove_small_masks, is the culprit.
Indeed, I took a quick look, and it seems to do many binary morphology operations over the masks and to be written mostly in Python without numba.
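The per-mask pattern presumably looks something like the sketch below (a hypothetical helper, not copied from cellpose): one find_objects pass, then a separate binary_fill_holes call per mask on its bounding box. With 20k+ masks that is 20k+ Python-level SciPy calls, which would be consistent with the slowdown:

```python
import numpy as np
from scipy import ndimage


def fill_holes_per_mask(masks):
    """Fill holes in each labeled mask separately, working inside its bounding box.

    Any pixel enclosed by a mask is relabeled with that mask's id, including
    pixels of another mask that lies entirely inside it.
    """
    filled = masks.copy()
    slices = ndimage.find_objects(masks)  # entry i is the bounding box of label i + 1
    for index, slc in enumerate(slices):
        if slc is None:  # this label id does not occur in the image
            continue
        label = index + 1
        region = masks[slc] == label
        region_filled = ndimage.binary_fill_holes(region)
        # filled[slc] is a view, so this assignment writes into filled
        filled[slc][region_filled] = label
    return filled
```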
I will investigate further next week, though. Thanks for the help in the meantime 👋
Oh, that makes sense, and it also explains why I ran into issues when calling the function "out of context", since the Numba acceleration wouldn't apply then. I suppose it needs more testing.
As for fill_holes_and_remove_small_masks, it uses scipy.ndimage.morphology.binary_fill_holes for many of its operations, which can get very slow with many masks. There is an alternative called fill_voids, which I've also opened a PR for, but I should do more rigorous testing to see whether it's compatible with Numba.
I can report that removing utils.fill_holes_and_remove_small_masks leads to acceptable performance in our case without loss of quality, so for now we can work with that.
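If the small-mask filter is still wanted without the hole filling, the sizes of all labels can be computed at once with np.bincount instead of mask by mask. A minimal sketch; `remove_small_masks` and the `min_size` value are placeholders (not cellpose's API or defaults), and labels are not renumbered afterwards:

```python
import numpy as np


def remove_small_masks(masks, min_size=15):
    """Zero out labels with fewer than min_size pixels, without a per-mask loop."""
    sizes = np.bincount(masks.ravel())  # pixel count for every label id
    too_small = sizes < min_size
    too_small[0] = False  # leave the background label alone
    out = masks.copy()
    # Index the per-label flags with the label image to get a per-pixel mask
    out[too_small[masks]] = 0
    return out
```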
However, along similar lines, we also run into a problem when running full 3D segmentation on those large datasets. The model runs fine on all 3 planes, possibly because it is batched. After that, however, there seems to be a refinement step that is not batched, and we hit a CUDA memory error. So it's either line 256 or line 533. Does anyone know whether those steps can be batched easily, or should I force them to run on the CPU?
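Without knowing exactly what those lines compute in your version, one generic option is to split the work into chunks. This assumes the step treats each pixel/point independently; if it needs all points at once, chunking does not apply. A minimal NumPy sketch with a hypothetical `apply_in_chunks` helper:

```python
import numpy as np


def apply_in_chunks(fn, points, chunk_size=1_000_000):
    """Apply fn to consecutive slices of points along axis 0 and concatenate the results.

    Only valid when fn handles each row independently of the others.
    """
    if points.shape[0] == 0:
        return fn(points)
    outputs = []
    for start in range(0, points.shape[0], chunk_size):
        outputs.append(fn(points[start:start + chunk_size]))
    return np.concatenate(outputs, axis=0)
```

With PyTorch tensors on the GPU the same loop works with torch.cat in place of np.concatenate; moving each chunk's result to the CPU with .cpu() before concatenating keeps the accumulated output off the GPU, so peak GPU memory scales with chunk_size rather than with the full volume.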