
What is the shape of the encoder's boxes?

raoxinyu4977 opened this issue 9 months ago • 7 comments

I attempted to set the shape of the encoder's input boxes to (4, 10, 4), representing (bs, num_boxes, 4), where the last dimension holds the two box corners. However, during this operation:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds box prompts."""
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

The boxes tensor is reshaped to (bs*num_boxes, 2, 2), and the output has shape (bs*num_boxes, 2, 256). However, in the forward function, the sparse embeddings tensor sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) has a shape of (10, 0, 256). When concatenating with sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1), the dimensions don't align.
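
A minimal standalone trace of the shape mismatch (a sketch with placeholder tensors, assuming bs=4, num_boxes=10, embed_dim=256 from the example):

import torch

boxes = torch.rand(4, 10, 4)                             # (bs, num_boxes, 4)
coords = boxes.reshape(-1, 2, 2)                         # (bs*num_boxes, 2, 2) = (40, 2, 2)
box_embeddings = torch.rand(coords.shape[0], 2, 256)     # stand-in for _embed_boxes output
sparse_embeddings = torch.empty(boxes.shape[0], 0, 256)  # (bs, 0, embed_dim)
torch.cat([sparse_embeddings, box_embeddings], dim=1)    # RuntimeError: dim 0 sizes differ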

raoxinyu4977 avatar Apr 28 '24 04:04 raoxinyu4977

The expected shape for box inputs is Bx4. If you're using points & labels, the shapes should be BxNx2 and BxN respectively, where B is the batch size and N is the number of points. So as-is, the model doesn't support having multiple box prompts for a single mask in the same way that it supports having multiple points (i.e. there is no 'N' component in the box shape).
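
For reference, a rough sketch of a standard batched single-box call via SamPredictor.predict_torch, with those shapes (the box coordinates here are made-up placeholders, and this assumes predictor.set_image(image) has already been called on a numpy HxWxC image):

import torch

box = torch.tensor([[100.0, 150.0, 400.0, 500.0]], device=predictor.device)  # (B, 4) XYXY, B=1
box = predictor.transform.apply_boxes_torch(box, image.shape[:2])  # map to model input coords
masks, scores, logits = predictor.predict_torch(
    point_coords=None,  # would be (B, N, 2) if using points
    point_labels=None,  # would be (B, N) to match
    boxes=box,
    multimask_output=False,
)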

In theory you could modify the code to support having 'N' boxes, by generating the corner_embedding output for each of the 'N' (10 in the example you gave) boxes and concatenating them together into a single corner_embedding. I think the result should have a shape of Bx(2N)x(embed_dim) to work with the existing code, though it's unclear how well this would work, since it's not part of the original model behavior/training (but maybe worth trying).
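
For example, a rough (untested) sketch of what a modified _embed_boxes could look like inside the PromptEncoder, accepting (B, N, 4) boxes and returning Bx(2N)x(embed_dim):

def _embed_boxes_multi(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds (B, N, 4) box prompts as (B, 2N, embed_dim) corner embeddings."""
    batch_size, num_boxes, _ = boxes.shape
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)  # (B*N, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    # Fold the per-box corner pairs back into the batch dimension
    return corner_embedding.reshape(batch_size, 2 * num_boxes, -1)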

heyoeyo avatar Apr 28 '24 13:04 heyoeyo

Thanks, I know.

raoxinyu4977 avatar Apr 29 '24 08:04 raoxinyu4977

Hello, I've encountered the same issue as you. Could you please share how you resolved it? I would greatly appreciate your assistance. Thank you.

wu2233 avatar Sep 24 '24 09:09 wu2233

There's some code on the SAMv2 issue board that provides support for having multiple boxes for a single prompt. That code references changes to the newer code base, but the equivalent code for SAMv1 can be found in the prompt_encoder script. Though it seems both SAMv2 and SAMv1 perform poorly when using more than 1 box.

heyoeyo avatar Sep 24 '24 16:09 heyoeyo

Thank you very much! I will try.

wu2233 avatar Sep 25 '24 01:09 wu2233

Hello, I've encountered a new issue: for example, in one training batch I have 2 images; the first image has 2 bounding boxes and the second image has 3 bounding boxes. In this case, how should I conduct the training? Can the program understand the correspondence between images and bounding boxes within a batch?

wu2233 avatar Sep 25 '24 08:09 wu2233

In general, if you have differently shaped data, it needs to be processed in separate batches. In this case, if you had multiple images with 2 bounding boxes you could batch all of them together, and likewise for the images with 3 bounding boxes.
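
For instance, a simple (hypothetical) helper for splitting a dataset into same-shape groups by box count:

from collections import defaultdict

def group_by_box_count(samples):
    """samples: iterable of (image, boxes) pairs, with boxes shaped (num_boxes, 4)."""
    groups = defaultdict(list)
    for image, boxes in samples:
        groups[boxes.shape[0]].append((image, boxes))
    return groups  # e.g. {2: [...], 3: [...]}, one batch (or dataloader) per key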

Alternatively, the SAM model includes a 'not a point' embedding (not_a_point_embed) that can be used to pad the prompts, so you could use this to make the 2-box prompt tensors the same shape as the 3-box prompts.

Each box prompt adds two 'points' to the prompt tensor, so to pad a 2-box prompt (4 entries) to match the shape of a 3-box prompt (6 entries), I think you'd need to do something like:

# Pad a 2-box prompt encoding (B, 4, embed_dim) to match a 3-box encoding (B, 6, embed_dim).
# not_a_point_embed is an nn.Embedding, so take its .weight, which has shape (1, embed_dim)
pad_embed = predictor.model.prompt_encoder.not_a_point_embed.weight.reshape(1, 1, -1)
pad_embed = pad_embed.expand(sparse_embeddings.shape[0], 2, -1)  # two padding 'points'
sparse_embeddings = torch.cat([sparse_embeddings, pad_embed], dim=1)

This would require modifying the sparse_embeddings that are generated by the prompt encoder (which normally happens inside the predict function).
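
Putting it together, a rough sketch of that flow (boxes_2 is a made-up name for a (B, 2, 4) box tensor, and this assumes the prompt encoder has been modified to accept multiple boxes as discussed above):

# Run the prompt encoder directly instead of going through predict:
sparse_embeddings, dense_embeddings = predictor.model.prompt_encoder(
    points=None, boxes=boxes_2, masks=None,
)
# Pad from 4 corner embeddings (2 boxes) to 6 (3 boxes), as shown above,
# so the 2-box batches match the 3-box batches in shape:
pad = predictor.model.prompt_encoder.not_a_point_embed.weight.reshape(1, 1, -1)
sparse_embeddings = torch.cat(
    [sparse_embeddings, pad.expand(sparse_embeddings.shape[0], 2, -1)], dim=1
)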

heyoeyo avatar Sep 25 '24 14:09 heyoeyo