RobustSAM icon indicating copy to clipboard operation
RobustSAM copied to clipboard

Weird output from SamPredictor

Open rongjhan opened this issue 1 year ago • 2 comments

I implemented predict in two ways: one follows the demo in eval.py from the repository, and the other uses SamPredictor — but I get different outputs. I think SamPredictor should just transform the ndarray to a tensor first and then apply exactly the same operations after that, so I can't tell where I made a mistake.
image image

import numpy as np
import torch.types
from robust_segment_anything import SamPredictor, sam_model_registry
from robust_segment_anything.utils.transforms import ResizeLongestSide 
from PIL import Image
from typing import Literal



# Path to the RobustSAM ViT-B checkpoint; only valid with model_size="base" below.
ckpt_cache = r"path\to\robustsam_checkpoint_b.pth"
img_path =  r"path\to\rain.jpg"   # rain.jpg can be downloaded from the repo's demo_images folder
# One XYXY box prompt, in original-image pixel coordinates.
input_boxes =[[2.47010520553918, 1.390010846867655, 267.41756993667735, 296.76731580624437]] 

def sam_robust_predict(
    raw_image: Image.Image,
    input_points = None,
    input_labels = None,
    input_boxes = None,
    model_size: Literal["base", "large", "huge"] = "large",
    device: Literal["cpu", "cuda"] = "cpu",
) -> Image.Image:
    """Segment `raw_image` by building the batched input dict and calling
    ``model.predict`` directly (the eval.py-style path).

    Args:
        raw_image: input image; converted to RGB if needed.
        input_points: optional list of (x, y) point prompts in original-image coords.
        input_labels: optional point labels matching `input_points`.
        input_boxes: optional list of XYXY boxes in original-image coords.
        model_size: which ViT backbone to load. NOTE(review): the default
            "large" does not match the ViT-B checkpoint in `ckpt_cache`;
            callers must pass model_size="base" unless they change the path.
        device: where to run inference.

    Returns:
        A PIL image of the (single) predicted binary mask, scaled to 0/255.
    """
    # f"vit_{model_size[0]}" maps base/large/huge -> vit_b / vit_l / vit_h.
    model = sam_model_registry[f"vit_{model_size[0]}"](None, checkpoint=ckpt_cache)
    sam_transform = ResizeLongestSide(model.image_encoder.img_size)
    model = model.to(device)
    # Fix: run in eval mode, consistent with sam_robust_predict2 (which calls
    # sam.eval()). Training-mode layers could otherwise change the output.
    model.eval()

    if raw_image.mode != "RGB":
        raw_image = raw_image.convert("RGB")

    # Image: HWC uint8 -> 1x3xHxW float, resized so the longest side matches
    # the encoder's expected size.
    image = np.array(raw_image, dtype=np.uint8)
    image_t = torch.tensor(image, dtype=torch.uint8).unsqueeze(0).to(device)
    image_t = torch.permute(image_t, (0, 3, 1, 2))
    image_t_transformed = sam_transform.apply_image_torch(image_t.float())

    data_dict = {'image': image_t_transformed}

    # Prompts: `np.array(x) if x` treats empty lists the same as None.
    np_input_boxes = np.array(input_boxes) if input_boxes else None
    np_input_points = np.array(input_points) if input_points else None
    np_input_labels = np.array(input_labels) if input_labels else None

    # Box prompt: map into resized-image coordinates; final shape (1, B, 4).
    if np_input_boxes is not None:
        # Fix: use the validated np_input_boxes (original passed the raw
        # `input_boxes` list here, inconsistent with the check above).
        box_t = torch.Tensor(np_input_boxes).unsqueeze(0).to(device)
        data_dict['boxes'] = sam_transform.apply_boxes_torch(box_t, image_t.shape[-2:]).unsqueeze(0)

    # Point prompt: map coords into resized-image space; labels pass through.
    if np_input_points is not None:
        input_label = torch.Tensor(np_input_labels).to(device)
        point_t = torch.Tensor(np_input_points).to(device)
        data_dict['point_coords'] = sam_transform.apply_coords_torch(point_t, image_t.shape[-2:]).unsqueeze(0)
        data_dict['point_labels'] = input_label.unsqueeze(0)

    # Pre-resize (original) H, W so the model can upscale masks back.
    data_dict['original_size'] = image_t.shape[-2:]
    with torch.no_grad():
        batched_output = model.predict(None, [data_dict], multimask_output=False, return_logits=False)

    output_mask = batched_output[0]['masks']
    h, w = output_mask.shape[-2:]
    # Fix: .cpu() before .numpy() so device="cuda" does not crash.
    img = Image.fromarray(output_mask.reshape(h, w).cpu().numpy().astype(np.uint8) * 255)
    img.show()
    return img


def sam_robust_predict2(
    raw_image: Image.Image,
    input_points = None,
    input_labels = None,
    input_boxes = None,
    model_size: Literal["base", "large", "huge"] = "base",
    device: Literal["cpu", "cuda"] = "cpu",
) -> Image.Image:
    """Segment `raw_image` via the SamPredictor convenience API.

    Args:
        raw_image: input image; converted to RGB if needed.
        input_points: optional list of (x, y) point prompts.
        input_labels: optional point labels matching `input_points`.
        input_boxes: optional list of XYXY boxes in original-image coords.
        model_size: which ViT backbone to load (checkpoint is ViT-B).
        device: where to run inference.

    Returns:
        A PIL image of the (single) predicted binary mask, scaled to 0/255.
    """
    sam = sam_model_registry[f"vit_{model_size[0]}"](None, checkpoint=ckpt_cache)
    # Fix: the `device` parameter was previously ignored — the model always
    # stayed on its load device regardless of the argument.
    sam = sam.to(device)
    sam.eval()
    predictor = SamPredictor(sam)

    if raw_image.mode != "RGB":
        raw_image = raw_image.convert("RGB")

    # SamPredictor handles resizing/normalization internally.
    predictor.set_image(np.array(raw_image, dtype=np.uint8))

    # Prompts: `np.array(x) if x` treats empty lists the same as None.
    np_input_boxes = np.array(input_boxes) if input_boxes else None
    np_input_points = np.array(input_points) if input_points else None
    np_input_labels = np.array(input_labels) if input_labels else None

    masks, scores, logits = predictor.predict(
        point_coords=np_input_points,
        point_labels=np_input_labels,
        box=np_input_boxes,
        multimask_output=False,
        # return_logits=True
    )

    h, w = masks.shape[-2:]
    mask_2d = masks.reshape(h, w)
    # Fix: upstream SAM's SamPredictor.predict returns numpy arrays, which
    # have no .numpy() method; handle both tensor and ndarray returns.
    if torch.is_tensor(mask_2d):
        mask_2d = mask_2d.cpu().numpy()
    img = Image.fromarray(mask_2d.astype(np.uint8) * 255)
    img.show()
    return img


# Run both inference paths on the same image and box prompt so their
# outputs can be compared side by side (direct predict first, then SamPredictor).
for _predict in (sam_robust_predict, sam_robust_predict2):
    _predict(
        Image.open(img_path),
        input_points=None,
        input_labels=None,
        input_boxes=input_boxes,
        model_size="base",
    )

rongjhan avatar Aug 22 '24 17:08 rongjhan