segment-anything
Mask generated outside of bounding box? [point + box example]
With code: segment_ai_onxx_model.zip
Trying a point and a box: why is the mask outside of the bounding box?
Would you mind posting the code here?
@JordanMakesMaps, it is in the zip file. See the attachment.
Please see the instructions in the onnx model tutorial notebook on how to format the inputs correctly for the onnx model: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb
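For reference, the encoding the notebook describes for a combined point + box prompt can be sketched with plain numpy (a minimal sketch with made-up values, not taken from the attachment): the box is flattened into two corner "points" labelled 2 (top-left) and 3 (bottom-right), stacked after the ordinary point prompts, and no padding point is added when a box is present.

```python
import numpy as np

# Illustrative sketch of the prompt encoding described in the ONNX notebook;
# the values below are made up and not taken from the attached script.
point = np.array([[550.0, 200.0]])            # one point prompt (x, y)
point_label = np.array([0.0])                 # 0 = background point, 1 = foreground point
box = np.array([490.0, 130.0, 670.0, 290.0])  # [x0, y0, x1, y1], top-left then bottom-right

# A box is encoded as two extra "points": top-left corner (label 2) and
# bottom-right corner (label 3). No padding point is needed when a box is present.
coords = np.concatenate([point, box.reshape(2, 2)], axis=0)[None, :, :].astype(np.float32)
labels = np.concatenate([point_label, [2.0, 3.0]], axis=0)[None, :].astype(np.float32)

print(coords.shape, labels.shape)  # (1, 3, 2) (1, 3)
# The coordinates still need predictor.transform.apply_coords(...) to rescale
# them to the 1024-long-side frame before being fed as "point_coords".
```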
@nikhilaravi I followed everything, but the notebook doesn't explicitly say anything about the format of the inputs. I also adjusted orig_im_size and still get the same result: a mask outside the box.
Try it:
```python
import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb=1024"
import warnings
import torch        # needs installation
import torchvision  # needs installation
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
def show_mask(mask, ax):
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
def clear_cuda_memory():
    """
    Frees GPU memory by emptying the CUDA cache and running a garbage collection pass.
    """
    # Check if a GPU is available
    if torch.cuda.is_available():
        # Set the current device to the GPU
        torch.cuda.set_device(0)
        # Release all cached memory
        torch.cuda.empty_cache()
        # Perform a manual garbage collection to ensure that all unused tensors are removed
        import gc
        gc.collect()

# Call the clear_cuda_memory function at the beginning of the script
clear_cuda_memory()
# Export an ONNX model.
# Set the path below to a SAM model checkpoint, then load the model.
# This will be needed both to export the model and to calculate
# embeddings for the model.
checkpoint = "./sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Open the image file.
image_path = "./dog.jpg"
image = Image.open(image_path)

# Get the size (dimensions) of the image.
width, height = image.size

sam = sam_model_registry[model_type](checkpoint=checkpoint)
""" The script segment-anything/scripts/export_onnx_model.py can be used to export the necessary portion of SAM. Alternatively, run the following code to export an ONNX model. If you have already exported a model, set the path below and skip to the next section. Assure that the exported ONNX model aligns with the checkpoint and model type set above. This notebook expects the model was exported with the parameter return_single_mask=True. """
onnx_model_path = None # Set to use an already exported model, then skip to the next section.
onnx_model_path = "./sam_onnx_example.onnx"
onnx_model = SamOnnxModel(sam, return_single_mask=True)
dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([height, width], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
""" If desired, the model can additionally be quantized and optimized. We find this improves web runtime significantly for negligible change in qualitative performance. Run the next cell to quantize the model, or skip to the next section otherwise. """ onnx_model_quantized_path = "sam_onnx_quantized_example.onnx" quantize_dynamic( model_input=onnx_model_path, model_output=onnx_model_quantized_path, optimize_model=True, per_channel=False, reduce_range=False, weight_type=QuantType.QUInt8, )
onnx_model_path = onnx_model_quantized_path
# EXAMPLE
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('on')
plt.show()
# USING AN ONNX MODEL
"""
Here, as an example, we use onnxruntime in Python on CPU to execute the ONNX
model. However, any platform that supports an ONNX runtime could be used in
principle. Launch the runtime session below:
"""
ort_session = onnxruntime.InferenceSession(onnx_model_path)

"""
To use the ONNX model, the image must first be pre-processed using the SAM
image encoder. This is a heavier-weight process best performed on GPU.
SamPredictor can be used as normal, then .get_image_embedding() will retrieve
the intermediate features.
"""
sam.to(device='cuda')
predictor = SamPredictor(sam)
predictor.set_image(image)

image_embedding = predictor.get_image_embedding().cpu().numpy()
print(image_embedding.shape)
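# For the SAM image encoder this should print (1, 256, 64, 64).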
""" The ONNX model has a different input signature than SamPredictor.predict. The following inputs must all be supplied. Note the special cases for both point and mask inputs. All inputs are np.float32.
image_embeddings: The image embedding from predictor.get_image_embedding(). Has a batch index of length 1. point_coords: Coordinates of sparse input prompts, corresponding to both point inputs and box inputs. Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner. Coordinates must already be transformed to long-side 1024. Has a batch index of length 1. point_labels: Labels for the sparse input prompts. 0 is a negative input point, 1 is a positive input point, 2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point. If there is no box input, a single padding point with label -1 and coordinates (0.0, 0.0) should be concatenated. mask_input: A mask input to the model with shape 1x1x256x256. This must be supplied even if there is no mask input. In this case, it can just be zeros. has_mask_input: An indicator for the mask input. 1 indicates a mask input, 0 indicates no mask input. orig_im_size: The size of the input image in (H,W) format, before any transformation. Additionally, the ONNX model does not threshold the output mask logits. To obtain a binary mask, threshold at sam.mask_threshold (equal to 0.0). """ ############################
EXAMPLE SINGLE POINT INPUT
############################ input_point = np.array([[605, 205]]) input_label = np.array([1]) # 1 (foreground point) or 0 (background point)
# Add a batch index, concatenate a padding point, and transform.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
# Create an empty mask input and an indicator for no mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
# Package the inputs to run in the ONNX model.
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}
# Predict a mask and threshold it.
masks, _, low_res_logits = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

print(masks.shape)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# MASK INPUT EXAMPLE:
# Now that we have a mask created by a single point, we use it
# along with a new set of points. Load the previous mask.
input_point = np.array([[600, 200], [560, 200]])
input_label = np.array([1, 1])
# Use the mask output from the previous run. It is already in the correct
# form for input to the ONNX model.
onnx_mask_input = low_res_logits
# Transform the points as in the previous example.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
# The has_mask_input indicator is now 1, because we have a real mask input.
onnx_has_mask_input = np.ones(1, dtype=np.float32)
# Package inputs, then predict and threshold the mask.
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# EXAMPLE USING A BOX AND A POINT
input_box = np.array([490, 290, 670, 130])
input_point = np.array([[550, 200]])
input_label = np.array([0])
# Add a batch index, concatenate the box and point inputs, add the
# appropriate labels for the box corners, and transform. There is no
# padding point since the input includes a box.
onnx_box_coords = input_box.reshape(2, 2)
onnx_box_labels = np.array([2, 3])

onnx_coord = np.concatenate([input_point, onnx_box_coords], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, onnx_box_labels], axis=0)[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
# Package inputs, then predict and threshold the mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
```
@JordanMakesMaps, see the code above.
If I only use a box prompt, how do I do that?
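A minimal sketch for a box-only prompt, assuming the predictor, image, image_embedding, and ort_session set up in the script above. Per the input description in that script, the box corners are passed as two points with labels 2 and 3, and no padding point is needed when a box is present. The box values here are illustrative and use [x0, y0, x1, y1] order with the top-left corner first:

```python
# Box-only prompt sketch (illustrative values; reuses predictor, image,
# image_embedding, and ort_session from the script above).
input_box = np.array([490, 130, 670, 290])  # [x0, y0, x1, y1], top-left then bottom-right

# Encode the box as two labelled corner points; no padding point is needed.
onnx_coord = input_box.reshape(2, 2)[None, :, :].astype(np.float32)
onnx_label = np.array([2, 3])[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# No mask input for this example.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}
masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
```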