
ONNX model produces worse results than its PyTorch counterpart

Open · james-imi opened this issue 1 year ago · 1 comment

I get different results from the PyTorch and ONNX paths for a fine-tuned model prompted with bounding boxes only. Here is the PyTorch prediction:

PyTorch Prediction

import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor

# Bounding box in (x1, y1, x2, y2) pixel coordinates
bbox = np.array([1055, 412, 1286, 991])

predictor = SamPredictor(sam)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    box=bbox,
    multimask_output=False,
)
plt.imshow(masks[0])

and get a mask like this, which is correct: [image]

ONNX Prediction

# Convert to ONNX
import torch
import onnxruntime
from segment_anything.utils.onnx import SamOnnxModel

# Wrap the SAM model for export; only the prompt encoder and mask decoder
# go into the ONNX graph
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# Dummy inputs matching the decoder's expected shapes, used only for tracing
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, sam.prompt_encoder.embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *[4 * x for x in embed_size], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

with open(model_output_path, "wb") as f:
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        f,
        export_params=True,
        verbose=False,
        opset_version=15,
        do_constant_folding=True,
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={
            "point_coords": {1: "num_points"},
            "point_labels": {1: "num_points"},
        },
    )
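As a sanity check, the exported file can be validated with the standard ONNX checker before creating a session (a minimal sketch, assuming the onnx package is installed):

import onnx

# Validate the structure of the exported graph before running inference
onnx_graph = onnx.load(model_output_path)
onnx.checker.check_model(onnx_graph)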

# Create an ONNX Runtime session for the exported model
ort_session = onnxruntime.InferenceSession(model_output_path)
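For completeness, the image_embedding fed to the session below comes from the PyTorch image encoder, which is not part of the exported model; following the official SAM ONNX example:

# The ONNX model contains only the prompt encoder and mask decoder, so the
# image embedding is computed once with the PyTorch encoder (set_image was
# already called above)
image_embedding = predictor.get_image_embedding().cpu().numpy()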

# Encode the bounding box as a pair of points with labels 2 (top-left corner)
# and 3 (bottom-right corner), as the decoder expects
onnx_box_coords = bbox.reshape(2, 2)
onnx_box_labels = np.array([2, 3])

onnx_coord = onnx_box_coords[None, :, :]
onnx_label = onnx_box_labels[None, :].astype(np.float32)

# Transform coordinates into the model's 1024-pixel input frame
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# No prior mask prompt, so pass zeros and set the has_mask flag to 0
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.array([0], dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

# Run the decoder and binarize the logits at the model's mask threshold
masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold
plt.imshow(masks[0][0])

and get this incorrect mask: [image]
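For reference, a quick way to quantify the mismatch between the two paths (a sketch; pt_mask and onnx_mask are placeholder names for the PyTorch masks[0] and the ONNX masks[0][0] computed above):

def mask_iou(a, b):
    # Intersection-over-union of two same-shaped boolean masks
    a, b = a.astype(bool), b.astype(bool)
    union = np.logical_or(a, b).sum()
    return np.logical_and(a, b).sum() / union if union else 1.0

# e.g. mask_iou(pt_mask, onnx_mask); values near 1.0 would mean the two
# paths actually agree and the difference is only in visualization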

Any idea why this happens?

james-imi avatar Mar 20 '24 18:03 james-imi

Try this again; maybe you can get a usable mask:

mask = masks[0][0]
mask = (mask > 0).astype('uint8') * 255
plt.imshow(mask)
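(Note that predictor.model.mask_threshold defaults to 0.0 in SAM, so thresholding at 0 should give the same binary mask as masks > predictor.model.mask_threshold above; the *255 scaling just makes it render clearly.)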

zhangzeyang000 avatar Jun 06 '24 03:06 zhangzeyang000