CellSeg
CellSeg copied to clipboard
Mask growth breaks for edge cases
Hi,
while investigating the mask growing method, I have come across some unexpected behavior, which looks incorrect to me. For example, notice how the mask on the left side of the image occupies pixels that overlap with other cells from the original segmentation.
Original image:
Image grown by 1px:
I have extracted the corresponding code snippets from the CVMask class to create this standalone example for testing:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.morphology import disk, dilation
from scipy.ndimage.morphology import binary_dilation
from sklearn.neighbors import kneighbors_graph
from scipy.spatial.distance import cdist
# adapted from CVMask
def compute_centroids(flatmasks):
masks = flatmasks
num_masks = len(np.unique(masks)) - 1
indices = np.where(masks != 0)
values = masks[indices[0], indices[1]]
maskframe = pd.DataFrame(np.transpose(np.array([indices[0], indices[1], values]))).rename(columns = {0:"x", 1:"y", 2:"id"})
centroids = maskframe.groupby('id').agg({'x': 'mean', 'y': 'mean'}).to_records(index = False).tolist()
return centroids
# adapted from CVMask
def remove_overlaps_nearest_neighbors(centroids, masks):
final_masks = np.max(masks, axis = 2)
collisions = np.nonzero(np.sum(masks > 0, axis = 2) > 1)
collision_masks = masks[collisions]
collision_index = np.nonzero(collision_masks)
collision_masks = collision_masks[collision_index]
collision_frame = pd.DataFrame(np.transpose(np.array([collision_index[0], collision_masks]))).rename(columns = {0:"collis_idx", 1:"mask_id"})
grouped_frame = collision_frame.groupby('collis_idx')
for collis_idx, group in grouped_frame:
collis_pos = np.expand_dims(np.array([collisions[0][collis_idx], collisions[1][collis_idx]]), axis = 0)
prevval = final_masks[collis_pos[0,0], collis_pos[0,1]]
mask_ids = list(group['mask_id'])
curr_centroids = np.array([centroids[mask_id - 1] for mask_id in mask_ids])
dists = cdist(curr_centroids, collis_pos)
closest_mask = mask_ids[np.argmin(dists)]
final_masks[collis_pos[0,0], collis_pos[0,1]] = closest_mask
return final_masks
# adapted from CVMask
def grow_masks(flatmasks, centroids, growth, method = 'Standard', num_neighbors = 30):
masks = flatmasks
num_masks = len(np.unique(masks)) - 1
# only looking at the standard method, but sequential also appears to have some issues
if method == 'Standard':
print("Standard growth selected")
masks = flatmasks
num_masks = len(np.unique(masks)) - 1
indices = np.where(masks != 0)
values = masks[indices[0], indices[1]]
maskframe = pd.DataFrame(np.transpose(np.array([indices[0], indices[1], values]))).rename(columns = {0:"x", 1:"y", 2:"id"})
cent_array = maskframe.groupby('id').agg({'x': 'mean', 'y': 'mean'}).to_numpy()
connectivity_matrix = kneighbors_graph(cent_array, num_neighbors).toarray() * np.arange(1, num_masks + 1)
connectivity_matrix = connectivity_matrix.astype(int)
labels = {}
for n in range(num_masks):
connections = list(connectivity_matrix[n, :])
connections.remove(0)
layers_used = [labels[i] for i in connections if i in labels]
layers_used.sort()
currlayer = 0
for layer in layers_used:
if currlayer != layer:
break
currlayer += 1
labels[n + 1] = currlayer
possible_layers = len(list(set(labels.values())))
label_frame = pd.DataFrame(list(labels.items()), columns = ["maskid", "layer"])
image_h, image_w = masks.shape
expanded_masks = np.zeros((image_h, image_w, possible_layers), dtype = np.uint32)
grouped_frame = label_frame.groupby('layer')
for layer, group in grouped_frame:
currids = list(group['maskid'])
masklocs = np.isin(masks, currids)
expanded_masks[masklocs, layer] = masks[masklocs]
dilation_mask = disk(1)
grown_masks = np.copy(expanded_masks)
for _ in range(growth):
for i in range(possible_layers):
grown_masks[:, :, i] = dilation(grown_masks[:, :, i], dilation_mask)
return remove_overlaps_nearest_neighbors(centroids, grown_masks)
example_data = np.array([[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0, 1, 1, 1, 1],
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 0, 0, 0, 1, 1, 1, 1, 1],
[6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1],
[6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
[2, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 4],
[2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 4, 4],
[2, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 9, 9, 9],
[2, 0, 0, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9],
[0, 0, 0, 0, 7, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
[3, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
[3, 3, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
[3, 3, 3, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
[3, 3, 8, 8, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9]])
centroids = compute_centroids(example_data)
masks_grown = grow_masks(example_data, centroids, 1, method = 'Standard', num_neighbors = 8)
plt.imshow(example_data)
plt.show()
plt.imshow(masks_grown)
plt.show()
I would highly appreciate if you could tell me if I am using this method wrong, or if this is actually a bug within the method. Thank you very much in advance!