Grounded-Segment-Anything

get_grounding_output and predict_with_classes return different results

Open mhyeonsoo opened this issue 1 year ago • 2 comments

Hi,

I noticed that a new module, which takes a list of classes as the prompt input, has been added to GroundingDINO. I tried both `get_grounding_output` and `predict_with_classes` and saw that the results are different.

1. get_grounding_output

def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image with a free-text caption.

    Before it is passed to the model, the caption is lower-cased, stripped and
    terminated with "." if needed. Queries whose best token score exceeds
    ``box_threshold`` are kept. Each kept query is labelled with the caption
    tokens that score above ``text_threshold``.

    Args:
        model: grounding model. ``model.to(device)`` is called on it, and it
            must expose a ``tokenizer`` attribute.
        image: a single image tensor without a batch dimension (one is added
            here). NOTE(review): presumably the normalized tensor returned by
            load_image — confirm.
        caption: text prompt, e.g. "zucchini . carrot . asparagus".
        box_threshold: minimum max-token score for a query to be kept.
        text_threshold: minimum per-token score for a token to join a phrase.
        with_logits: if True, append "(score)" to each phrase. The score is
            the query's max token score, truncated (not rounded) to 4
            characters.
        device: device that the model and image are moved to for inference.

    Returns:
        ``(boxes_filt, pred_phrases)``:
            - boxes_filt: CPU tensor of kept boxes, shape (num_filt, 4).
            - pred_phrases: list with one phrase string per kept box.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    # sigmoid turns raw logits into per-token scores; [0] drops the batch dim of size 1.
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Keep queries whose best token score clears the box threshold. Boolean-mask
    # indexing already returns new tensors, so no defensive clone is needed.
    filt_mask = logits.max(dim=1).values > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    # Label each kept query with the caption tokens scoring above text_threshold.
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit in logits_filt:
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrase += f"({str(logit.max().item())[:4]})"
        pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases
# Thresholds and free-text prompt for the caption-based API (get_grounding_output).
BOX_TRESHOLD = 0.30
TEXT_TRESHOLD = 0.25
TEXT_PROMPT = "zucchini . carrot . asparagus"

image_pil, image = load_image(local_image_path)

# Run the grounding model with the caption prompt.
boxes_filt, pred_phrases = get_grounding_output(
    groundingdino_model,
    image,
    TEXT_PROMPT,
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD,
    device=device,
)

# NOTE(review): bare expression with no effect when run as a script; presumably
# a notebook display or a screenshot placeholder from the original issue.
image

2. predict_with_classes

def predict_with_classes(
        self,
        image: np.ndarray,
        classes: List[str],
        box_threshold: float,
        text_threshold: float
    ) -> sv.Detections:
        """Detect the given classes in an image and return supervision detections.

        The class names are joined with ". " into a single caption and passed
        to ``predict``. Each returned phrase is then mapped back to an index
        into ``classes`` via ``Model.phrases2classes``.

        Args:
            image: source image. It is passed to ``Model.preprocess_image`` as
                ``image_bgr``, and its ``shape`` is unpacked as
                (height, width, channels).
            classes: class names that make up the text prompt.
            box_threshold: box score threshold forwarded to ``predict``.
            text_threshold: text score threshold forwarded to ``predict``.

        Returns:
            The ``sv.Detections`` built by ``Model.post_process_result``, with
            ``class_id`` set from ``Model.phrases2classes``.

        NOTE(review): class ids depend on the decoded phrases matching entries
        of ``classes``. If the class strings are rewritten before this call
        (e.g. by a helper such as enhance_class_name), some phrases may not
        match — confirm how ``phrases2classes`` handles that case.

        Example::

            import cv2
            import supervision as sv

            image = cv2.imread(IMAGE_PATH)
            model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
            detections = model.predict_with_classes(
                image=image,
                classes=CLASSES,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD
            )
            annotated_image = sv.BoxAnnotator().annotate(scene=image, detections=detections)
        """
        prompt = ". ".join(classes)
        model_input = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=model_input,
            caption=prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device,
        )
        height, width, _channels = image.shape
        detections = Model.post_process_result(
            source_h=height,
            source_w=width,
            boxes=boxes,
            logits=logits,
        )
        detections.class_id = Model.phrases2classes(phrases=phrases, classes=classes)
        return detections
# Class list, thresholds and image path for the class-list API (predict_with_classes).
CLASSES = ["zucchini", "carrot", "asparagus"]
BOX_TRESHOLD = 0.30
TEXT_TRESHOLD = 0.25
SOURCE_IMAGE_PATH = 'sample.jpg'

# NOTE(review): SOURCE_IMAGE_PATH is not read in this snippet. `image` must be
# an (H, W, C) BGR array, because predict_with_classes passes it as image_bgr
# and unpacks image.shape into three values. Confirm it is not the transformed
# tensor returned by load_image(...) in the get_grounding_output snippet.
# NOTE(review): enhance_class_name presumably rewrites the class strings, so the
# prompt may not match TEXT_PROMPT. Pass CLASSES unchanged for a like-for-like
# comparison with get_grounding_output (confirm what enhance_class_name returns).
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=enhance_class_name(class_names=CLASSES),
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD,
)

# NOTE(review): bare expression with no effect when run as a script; presumably
# a notebook display or a screenshot placeholder from the original issue.
image

Since the SAM module is the same in both versions, I guess the two DINO parts return different outputs. Is there a reason for this? Can I make the class-list input work at least as well as the `get_grounding_output` module?

Thanks,

mhyeonsoo avatar May 09 '23 04:05 mhyeonsoo

We will check the input prompts for the different APIs later; maybe there is something different in the preprocessing of the language prompt.

rentainhe avatar May 09 '23 05:05 rentainhe

@rentainhe Thanks for your quick response! I will be happy to wait for the updates :)

mhyeonsoo avatar May 09 '23 05:05 mhyeonsoo