GroundingDINO

Logits of inference

Open • Corallo opened this issue 11 months ago • 3 comments

Hi, if I understand correctly, the logits are the features of the detected objects.

I was wondering why, in the inference script, you extract only the first element (making the logits unusable) instead of keeping them all: https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/util/inference.py#L97

Corallo • Mar 06 '24 16:03

The max function returns a tuple containing the maximum values and the indices of those max values, at least when it's called with a dimension argument such as dim=1. So (I think) the [0] indexing is used to get only the max values themselves (not their indices), which are likely meant to be used as a confidence 'score' for each of the boxes.
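Here's a minimal sketch of that return value, using a small made-up tensor (the numbers are arbitrary and only for illustration). When `torch.max` is given a `dim`, the result can be unpacked like a regular tuple and also has named `values` and `indices` fields, so `[0]` and `.values` refer to the same thing:

```python
import torch

# Made-up scores: 3 boxes x 4 values each (arbitrary illustrative numbers)
scores = torch.tensor([[0.1, 0.8, 0.3, 0.2],
                       [0.5, 0.4, 0.9, 0.1],
                       [0.2, 0.2, 0.1, 0.3]])

result = scores.max(dim=1)   # named tuple with fields (values, indices)
values, indices = result     # unpacks like a regular tuple

print(values)    # per-row maximum: one 'score' per box
print(indices)   # column of each row's maximum
# tensor([1, 2, 3])

# Indexing with [0] is the same as reading the .values field
print(torch.equal(result[0], result.values))
# True

# Without a dim argument, max() returns a single 0-d tensor (no indices)
overall = scores.max()
print(overall.dim())
# 0
```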

heyoeyo • Mar 06 '24 19:03

@heyoeyo I'm not sure I follow. What is outputs["pred_logits"] supposed to represent in the first place?

Corallo • Mar 08 '24 10:03

As far as I understand, the logits are meant to be a numerical representation of information about each of the bounding boxes predicted by the model. They can be thought of as an array of 256 numbers per bounding box, with each value corresponding to one token position in the text prompt.
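To make the shapes concrete, here's a sketch with random stand-in data. The leading batch dimension on `outputs["pred_logits"]` and the box count of 5 are assumptions for illustration only (not checked against the model); the point is just that, for one image, each row is the 256-number array for one box:

```python
import torch

# Hypothetical sizes for illustration only
batch_size, num_boxes, num_logits = 1, 5, 256

# Random stand-in for outputs["pred_logits"] (assumed shape: batch x boxes x 256)
pred_logits = torch.randn(batch_size, num_boxes, num_logits)

# First (only) image in the batch: one row of 256 numbers per box
per_box = pred_logits[0]
print(per_box.shape)
# torch.Size([5, 256])

# The 256-number array describing the first predicted box
first_box = per_box[0]
print(first_box.shape)
# torch.Size([256])
```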

They seem to use the logits in 3 (similar) ways in that function:

  1. They're used to keep only the 'good' bounding box predictions. A box is considered 'good' if the largest of its 256 logit values is above some box_threshold value.
  2. They're used to figure out which part of the text prompt goes with each box. This happens in the get_phrases_from_posmap(...) function calls. It works in a similar way: logit values above some text_threshold indicate which part of the input text should be assigned to the box.
  3. They're used as an overall 'score' of the quality of the box + text label (the part that you linked).
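The three uses listed above can be sketched with made-up data. This is not the library's code: the threshold values, the assumption that the logits are already in the 0 to 1 range, and the simple "token positions above threshold" list (a stand-in for get_phrases_from_posmap, which maps those positions back to words via the tokenizer) are all assumptions for illustration:

```python
import torch

# Hypothetical threshold values, chosen only for this illustration
box_threshold = 0.35
text_threshold = 0.25

# Made-up per-box logits, assumed to already be in the 0..1 range:
# 4 boxes x 256 values, mostly low background values
logits = torch.full((4, 256), 0.05)
logits[0, 1] = 0.60   # box 0: strong match at token position 1
logits[0, 2] = 0.30   # box 0: weaker match at token position 2
logits[1, 5] = 0.20   # box 1: nothing above box_threshold
logits[2, 7] = 0.90   # box 2: strong match at token position 7
                      # box 3: background values only

# 1. Keep only 'good' boxes: largest logit above box_threshold
keep = logits.max(dim=1)[0] > box_threshold
kept_logits = logits[keep]  # shape: (num_kept, 256)

# 2. For each kept box, the token positions above text_threshold
#    (simplified stand-in for get_phrases_from_posmap)
token_positions = [
    (row > text_threshold).nonzero().flatten().tolist() for row in kept_logits
]

# 3. One overall score per kept box: its largest logit
scores = kept_logits.max(dim=1)[0]

print(keep.tolist())
# [True, False, True, False]
print(token_positions)
# [[1, 2], [7]]
print(scores)
```

Note that steps 1 and 3 both use the row-wise maximum; the thresholded comparison in step 1 is what discards boxes, while step 3 keeps the maxima themselves as scores.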

Anyways, the easiest way to understand the weird indexing on the output is to try it with some sample data. You can try running something like:

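```python
import torch

# Fixed sample values (instead of torch.randint(0, 10, (6, 2))) so the output below is reproducible
logits = torch.tensor([[5, 7],
                       [0, 9],
                       [2, 0],
                       [9, 7],
                       [6, 6],
                       [9, 0]])

print(logits)
# tensor([[5, 7],
#         [0, 9],
#         [2, 0],
#         [9, 7],
#         [6, 6],
#         [9, 0]])

print(logits.max(dim=1)[0])  # maximum value in each row
# tensor([7, 9, 2, 9, 6, 9])

print(logits.max(dim=1)[1])  # column index of each row's maximum
# tensor([1, 1, 0, 0, 0, 0])
```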

You can see that .max(dim=1)[0] just gives the maximum value along each row, while .max(dim=1)[1] gives the column position of each of those max values.
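As a small follow-up sketch, reusing the sample values from that example: the [1] part gives the same result as argmax along the same dimension, and each (row, column) pair points back at that row's maximum. For a tied row like [6, 6], PyTorch's documentation says the index of the first maximal value is returned, which is why that row gives 0:

```python
import torch

# Same sample values as in the example above
logits = torch.tensor([[5, 7],
                       [0, 9],
                       [2, 0],
                       [9, 7],
                       [6, 6],
                       [9, 0]])

values, indices = logits.max(dim=1)

# .max(dim=1)[1] matches argmax along the same dimension
print(torch.equal(indices, logits.argmax(dim=1)))
# True

# Using each row number with its column index recovers the row maxima
rows = torch.arange(logits.shape[0])
print(torch.equal(logits[rows, indices], values))
# True
```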

heyoeyo • Mar 08 '24 13:03