ONNX model produces worse results than PyTorch counterpart
I have the following PyTorch prediction code for a fine-tuned model, using only a bounding-box prompt.
PyTorch Prediction
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor

bbox = np.array([1055, 412, 1286, 991])  # box prompt in xyxy pixel coordinates
predictor = SamPredictor(sam)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    box=bbox,
    multimask_output=False,
)
plt.imshow(masks[0])
and get a mask like this (which is correct).
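For reference, sam, image, and model_output_path are assumed to be defined before the snippets here; a minimal sketch of one way to set them up (the checkpoint path, model type, and file names below are hypothetical, not from the original post):

import cv2
from segment_anything import sam_model_registry

# Hypothetical checkpoint and model type -- substitute your fine-tuned weights
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Hypothetical input image; SamPredictor expects an RGB array
image = cv2.imread("example.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Hypothetical path for the exported ONNX model
model_output_path = "sam_onnx.onnx"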
ONNX Prediction
# Convert to ONNX
import torch
from segment_anything.utils.onnx import SamOnnxModel

onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, sam.prompt_encoder.embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *[4 * x for x in embed_size], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
with open(model_output_path, "wb") as f:
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        f,
        export_params=True,
        verbose=False,
        opset_version=15,
        do_constant_folding=True,
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={
            "point_coords": {1: "num_points"},
            "point_labels": {1: "num_points"},
        },
    )
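Before running inference it can help to sanity-check the exported graph. A quick sketch using the onnx package (assuming it is installed; this is an editor-added check, not part of the original post):

import onnx

exported = onnx.load(model_output_path)
onnx.checker.check_model(exported)  # raises ValidationError if the graph is malformed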
# Run the exported model with ONNX Runtime
import onnxruntime

ort_session = onnxruntime.InferenceSession(model_output_path)
# Compute the image embedding with the PyTorch image encoder
image_embedding = predictor.get_image_embedding().cpu().numpy()
# Encode the bounding box as two points: label 2 = top-left corner, label 3 = bottom-right corner
onnx_box_coords = bbox.reshape(2, 2)
onnx_box_labels = np.array([2, 3])
onnx_coord = onnx_box_coords[None, :, :]
onnx_label = onnx_box_labels[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
# No mask prompt is used
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.array([0], dtype=np.float32)
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}
# Predict and threshold the logits into a binary mask
masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold
plt.imshow(masks[0][0])
and get an incorrect mask.
Any idea why this might happen?
Reply:
Try this again, maybe you can get a sample mask:

mask = masks[0][0]
mask = (mask > 0).astype('uint8') * 255
plt.imshow(mask)
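Beyond replotting, one way to narrow the problem down is to compare the two outputs numerically rather than visually. A minimal sketch, assuming masks_torch holds the boolean mask from the PyTorch path and masks_onnx the one from the ONNX path (both names are illustrative):

import numpy as np

# IoU close to 1.0 means the exported model agrees with PyTorch;
# close to 0.0 means the two paths diverge badly.
pt_mask = masks_torch[0]
ort_mask = masks_onnx[0][0]
iou = np.logical_and(pt_mask, ort_mask).sum() / np.logical_or(pt_mask, ort_mask).sum()
print(f"IoU between PyTorch and ONNX masks: {iou:.3f}")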