CellSeg icon indicating copy to clipboard operation
CellSeg copied to clipboard

Mask growth breaks for edge cases

Open MeyerBender opened this issue 10 months ago • 1 comments

Hi,

while investigating the mask growing method, I have come across some unexpected behavior, which looks incorrect to me. For example, notice how the mask on the left side of the image occupies pixels that overlap with other cells from the original segmentation.

Original image: image

Image grown by 1px: image

I have extracted the corresponding code snippets from the CVMask class to create this standalone example for testing:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.morphology import disk, dilation
from scipy.ndimage.morphology import binary_dilation
from sklearn.neighbors import kneighbors_graph
from scipy.spatial.distance import cdist

# adapted from CVMask
def compute_centroids(flatmasks):
    masks = flatmasks
    num_masks = len(np.unique(masks)) - 1
    indices = np.where(masks != 0)
    values = masks[indices[0], indices[1]]

    maskframe = pd.DataFrame(np.transpose(np.array([indices[0], indices[1], values]))).rename(columns = {0:"x", 1:"y", 2:"id"})
    centroids = maskframe.groupby('id').agg({'x': 'mean', 'y': 'mean'}).to_records(index = False).tolist()

    return centroids
    
# adapted from CVMask
def remove_overlaps_nearest_neighbors(centroids, masks):
        final_masks = np.max(masks, axis = 2)
        collisions = np.nonzero(np.sum(masks > 0, axis = 2) > 1)
        collision_masks = masks[collisions]
        collision_index = np.nonzero(collision_masks)
        collision_masks = collision_masks[collision_index]
        collision_frame = pd.DataFrame(np.transpose(np.array([collision_index[0], collision_masks]))).rename(columns = {0:"collis_idx", 1:"mask_id"})
        grouped_frame = collision_frame.groupby('collis_idx')
        for collis_idx, group in grouped_frame:
            collis_pos = np.expand_dims(np.array([collisions[0][collis_idx], collisions[1][collis_idx]]), axis = 0)
            prevval = final_masks[collis_pos[0,0], collis_pos[0,1]]
            mask_ids = list(group['mask_id'])
            curr_centroids = np.array([centroids[mask_id - 1] for mask_id in mask_ids])
            dists = cdist(curr_centroids, collis_pos)
            closest_mask = mask_ids[np.argmin(dists)]
            final_masks[collis_pos[0,0], collis_pos[0,1]] = closest_mask
        
        return final_masks

# adapted from CVMask
def grow_masks(flatmasks, centroids, growth, method = 'Standard', num_neighbors = 30):
    masks = flatmasks
    num_masks = len(np.unique(masks)) - 1

    # only looking at the standard method, but sequential also appears to have some issues
    if method == 'Standard':
        print("Standard growth selected")
        masks = flatmasks
        num_masks = len(np.unique(masks)) - 1
        indices = np.where(masks != 0)
        values = masks[indices[0], indices[1]]

        maskframe = pd.DataFrame(np.transpose(np.array([indices[0], indices[1], values]))).rename(columns = {0:"x", 1:"y", 2:"id"})
        cent_array = maskframe.groupby('id').agg({'x': 'mean', 'y': 'mean'}).to_numpy()
        connectivity_matrix = kneighbors_graph(cent_array, num_neighbors).toarray() * np.arange(1, num_masks + 1)
        connectivity_matrix = connectivity_matrix.astype(int)
        labels = {}
        for n in range(num_masks):
            connections = list(connectivity_matrix[n, :])
            connections.remove(0)
            layers_used = [labels[i] for i in connections if i in labels]
            layers_used.sort()
            currlayer = 0
            for layer in layers_used:
                if currlayer != layer: 
                    break
                currlayer += 1
            labels[n + 1] = currlayer

        possible_layers = len(list(set(labels.values())))
        label_frame = pd.DataFrame(list(labels.items()), columns = ["maskid", "layer"])
        image_h, image_w = masks.shape
        expanded_masks = np.zeros((image_h, image_w, possible_layers), dtype = np.uint32)

        grouped_frame = label_frame.groupby('layer')
        for layer, group in grouped_frame:
            currids = list(group['maskid'])
            masklocs = np.isin(masks, currids)
            expanded_masks[masklocs, layer] = masks[masklocs]

        dilation_mask = disk(1)
        grown_masks = np.copy(expanded_masks)
        for _ in range(growth):
            for i in range(possible_layers):
                grown_masks[:, :, i] = dilation(grown_masks[:, :, i], dilation_mask)
        return remove_overlaps_nearest_neighbors(centroids, grown_masks)
        
example_data = np.array([[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0, 1, 1, 1, 1],
       [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 0, 0, 0, 1, 1, 1, 1, 1],
       [6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1],
       [6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 4],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 4, 4],
       [2, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 9, 9, 9],
       [2, 0, 0, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9],
       [0, 0, 0, 0, 7, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
       [3, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
       [3, 3, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
       [3, 3, 3, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
       [3, 3, 8, 8, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9]])

centroids = compute_centroids(example_data)
masks_grown = grow_masks(example_data, centroids, 1, method = 'Standard', num_neighbors = 8)
plt.imshow(example_data)
plt.show()
plt.imshow(masks_grown)
plt.show()

I would highly appreciate if you could tell me if I am using this method wrong, or if this is actually a bug within the method. Thank you very much in advance!

MeyerBender avatar Apr 08 '24 12:04 MeyerBender