segment-anything
Improve Segmentation using Bounding Boxes
Is there a way to improve SamPredictor segmentation when using bounding boxes?
I have something I want to segment (fungal colonies that are growing into each other in a petri dish), and the predictive model doesn't do the greatest job of segmenting it:
This is relatively easy to do with a petri dish with separated colonies like this:
And the following script successfully segments the above petri dish with separated colonies:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 20:45:01 2024
@author: lemurbear
"""
# import modules
import cv2
import numpy as np
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv
# use cuda if available
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# import sam model
sam_checkpoint = "/Users/lemurbear/Downloads/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_predictor = SamPredictor(sam)
# import image and set image with mask_predictor
predictive_img_path = "/Users/lemurbear/Downloads/20231031_ERG24_C5_8.jpg"
image_bgr = cv2.imread(predictive_img_path)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
mask_predictor.set_image(image_rgb)
# define bounding boxes
box_01 = {'x': 370, 'y': 210, 'width': 175, 'height': 160, 'label': ''}
box_02 = {'x': 370, 'y': 317, 'width': 175, 'height': 160, 'label': ''}
box_03 = {'x': 370, 'y': 480, 'width': 175, 'height': 160, 'label': ''}
box_04 = {'x': 370, 'y': 560, 'width': 175, 'height': 160, 'label': ''}
box_05 = {'x': 370, 'y': 670, 'width': 175, 'height': 160, 'label': ''}
box_06 = {'x': 370, 'y': 811, 'width': 175, 'height': 160, 'label': ''}
box_07 = {'x': 370, 'y': 940, 'width': 175, 'height': 160, 'label': ''}
box_08 = {'x': 370, 'y': 1070, 'width': 175, 'height': 160, 'label': ''}
box_dict = {'box_01': box_01, 'box_02': box_02, 'box_03': box_03,
'box_04': box_04, 'box_05': box_05, 'box_06': box_06,
'box_07': box_07, 'box_08': box_08}
# assign bounding boxes to 'boxes' array
box_num = 0
boxes = {}
for this_box in box_dict:
    box_array = np.array([
        box_dict[this_box]['x'],
        box_dict[this_box]['y'],
        box_dict[this_box]['x'] + box_dict[this_box]['width'],
        box_dict[this_box]['y'] + box_dict[this_box]['height']
    ])
    boxes[box_num] = box_array
    box_num = box_num + 1
    print('colony_{0:0=2d}: {1}'.format(box_num, box_dict[this_box]))
print(boxes)
box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color=sv.Color.red(), color_lookup=sv.ColorLookup.INDEX)
# predict and plot a mask for each box
box_num = 1
petri_areas = {}
for this_box in boxes:
    box_name = 'box_{0:0=2d}'.format(box_num)
    mask_name = 'masks_{0:0=2d}'.format(box_num)
    score_name = 'scores_{0:0=2d}'.format(box_num)
    logit_name = 'logit_{0:0=2d}'.format(box_num)
    image_name = 'segmented_image_{0:0=2d}.jpg'.format(box_num)
    print(box_name, mask_name, score_name, logit_name, boxes[this_box])
    box_num = box_num + 1
    masks_this, scores_this, logits_this = mask_predictor.predict(
        box=boxes[this_box],
        multimask_output=True
    )
    detections_this = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks_this),
        mask=masks_this
    )
    # keep only the largest of the candidate masks
    detections_this = detections_this[detections_this.area == np.max(detections_this.area)]
    # convert pixel area to physical units via the calibration factor
    area_this = str(round((detections_this.area[0] / (28.346*5)**2), 3))
    print("Colony area: ", area_this)
    source_image = box_annotator.annotate(scene=image_bgr.copy(),
                                          detections=detections_this, skip_label=True)
    segmented_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections_this)
    sv.plot_images_grid(
        images=[source_image, segmented_image],
        grid_size=(1, 2),
        titles=['source image', 'segmented image']
    )
    # segmented_image was annotated on the BGR copy, so no colour conversion is needed
    cv2.imwrite(image_name, segmented_image)
    sv.plot_images_grid(
        images=masks_this,
        grid_size=(1, 4),
        size=(16, 4)
    )
If I change the box locations, it does an okay job with the first petri dish, but it is still patchy.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 20:45:01 2024
@author: lemurbear
"""
import cv2
import numpy as np
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
from segment_anything import sam_model_registry, SamPredictor
import supervision as sv
# use cuda if available
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
sam_checkpoint = "/Users/lemurbear/Downloads/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictive_img_path = "/Users/lemurbear/Downloads/20231212_085406.jpg"
mask_predictor = SamPredictor(sam)
image_bgr = cv2.imread(predictive_img_path)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
mask_predictor.set_image(image_rgb)
# define bounding boxes
box_01 = {'x': 360, 'y': 200, 'width': 230, 'height': 170, 'label': ''}
box_02 = {'x': 370, 'y': 317, 'width': 185, 'height': 160, 'label': ''}
box_03 = {'x': 340, 'y': 480, 'width': 230, 'height': 160, 'label': ''}
box_04 = {'x': 330, 'y': 570, 'width': 220, 'height': 160, 'label': ''}
box_05 = {'x': 330, 'y': 670, 'width': 210, 'height': 160, 'label': ''}
box_06 = {'x': 310, 'y': 811, 'width': 270, 'height': 160, 'label': ''}
box_07 = {'x': 300, 'y': 950, 'width': 280, 'height': 160, 'label': ''}
box_08 = {'x': 325, 'y': 1075, 'width': 220, 'height': 200, 'label': ''}
box_dict = {'box_01': box_01, 'box_02': box_02, 'box_03': box_03,
'box_04': box_04, 'box_05': box_05, 'box_06': box_06,
'box_07': box_07, 'box_08': box_08}
# assign bounding boxes to 'boxes' array
box_num = 0
boxes = {}
for this_box in box_dict:
    box_array = np.array([
        box_dict[this_box]['x'],
        box_dict[this_box]['y'],
        box_dict[this_box]['x'] + box_dict[this_box]['width'],
        box_dict[this_box]['y'] + box_dict[this_box]['height']
    ])
    boxes[box_num] = box_array
    box_num = box_num + 1
    print('colony_{0:0=2d}: {1}'.format(box_num, box_dict[this_box]))
print(boxes)
box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color=sv.Color.red(), color_lookup=sv.ColorLookup.INDEX)
# predict and plot a mask for each box
box_num = 1
petri_areas = {}
for this_box in boxes:
    box_name = 'box_{0:0=2d}'.format(box_num)
    mask_name = 'masks_{0:0=2d}'.format(box_num)
    score_name = 'scores_{0:0=2d}'.format(box_num)
    logit_name = 'logit_{0:0=2d}'.format(box_num)
    image_name = 'segmented_image_{0:0=2d}.jpg'.format(box_num)
    print(box_name, mask_name, score_name, logit_name, boxes[this_box])
    box_num = box_num + 1
    masks_this, scores_this, logits_this = mask_predictor.predict(
        box=boxes[this_box],
        multimask_output=True
    )
    detections_this = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks_this),
        mask=masks_this
    )
    # keep only the largest of the candidate masks
    detections_this = detections_this[detections_this.area == np.max(detections_this.area)]
    # convert pixel area to physical units via the calibration factor
    area_this = str(round((detections_this.area[0] / (28.346*5)**2), 3))
    print("Colony area: ", area_this)
    # petri_areas =
    source_image = box_annotator.annotate(scene=image_bgr.copy(),
                                          detections=detections_this, skip_label=True)
    segmented_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections_this)
    sv.plot_images_grid(
        images=[source_image, segmented_image],
        grid_size=(1, 2),
        titles=['source image', 'segmented image']
    )
    # segmented_image was annotated on the BGR copy, so no colour conversion is needed
    cv2.imwrite(image_name, segmented_image)
    sv.plot_images_grid(
        images=masks_this,
        grid_size=(1, 4),
        size=(16, 4)
    )
I'm not sure whether there are settings that could improve the quality of the segmentation, and I would love to hear suggestions.
Many thanks, Emily
There are a couple of options that might help.
- Some simple post-processing might be good enough if you just want cleaner-looking masks. In particular, morphological filtering (specifically 'closing') can help fill in the gaps in the masks. You can do this fairly easily using opencv (cv2); I think in your case you could do something like the following, just after you create the masks_this result:
# Set up morphological filter (change ksize to fill in bigger holes)
ksize = (15, 15)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize)
# Apply filter to each mask (as uint8) and convert back to boolean to match original data type
cleaner_masks_this = []
for mask in masks_this:
    mask_uint8 = np.uint8(mask) * 255
    new_bool_mask = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel) > 127
    cleaner_masks_this.append(new_bool_mask)
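If it helps, here's a rough (untested) sketch of how you could plug the cleaned masks back into the detections step from your script:

# stack the cleaned masks back into an (N, H, W) boolean array and rebuild
# the detections from them, mirroring the original script
cleaner_masks_this = np.stack(cleaner_masks_this)
detections_this = sv.Detections(
    xyxy=sv.mask_to_xyxy(masks=cleaner_masks_this),
    mask=cleaner_masks_this
)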
- You can try using the mask input to the predictor. This seems very finicky, but there are posts (see #347 or #169) that suggest you can get better results by iteratively feeding the mask predicted by SAM back into it and re-predicting.
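Roughly (an untested sketch, reusing the masks_this / scores_this / logits_this variables from the loop in your script), that iteration might look like:

# pick the best-scoring mask from a first predict() call, then feed its
# low-res logits back in as a mask prompt and re-predict
best_idx = np.argmax(scores_this)
masks_this, scores_this, logits_this = mask_predictor.predict(
    box=boxes[this_box],
    mask_input=logits_this[best_idx][None, :, :],  # predictor expects shape (1, 256, 256)
    multimask_output=False
)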
- If possible, you can try using point prompts instead of the box prompt (i.e. when calling mask_predictor.predict(...)). Obviously that's not improving things using boxes, but in case it's an option, it might help. Using both positive and negative prompts tends to help pick up more complex shapes, from what I've seen. Here's an example of the result (using the web demo) on what seemed like the toughest segmentation.
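In script form that could look something like this (an untested sketch; the coordinates are made up for illustration, and note that point prompts can be combined with your existing box prompt in the same call):

# positive points (label 1) inside the colony, a negative point (label 0)
# on the neighbouring colony; coordinates are made-up (x, y) pixel values
point_coords = np.array([[455, 290], [430, 250], [455, 380]])
point_labels = np.array([1, 1, 0])
masks_this, scores_this, logits_this = mask_predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    box=boxes[this_box],  # keeping the box prompt alongside the points
    multimask_output=True
)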
- And lastly, if you're planning to do a lot of this kind of segmentation, then using a variant of SAM that is fine-tuned for these kinds of images might help. I don't know much about this area, so I can't be of much help, but a quick search returned CellSam, which seems vaguely related and might be useful. Fine-tuning your own variant could be a lot of work, so it's only worthwhile if you're going to be working with a lot of these images.
Thank you so much! This is really thorough and helpful!
I've been using positive and negative points to help refine the mask, as suggested in option 3. I'll probably need to do something more like option 4 in the long run because I will be doing this quite a bit, but for now this is working better!
Again, thanks so much!
Cheers, Emily