
"Empty" box input

Open · 25benjaminli opened this issue 10 months ago · 3 comments

Is there a way to add an empty box input within a batched box input to the prompt encoder? Currently, I have to make a bounding box around the whole image, and I'm wondering whether that might lead to overfitting if I'm training the model.

25benjaminli avatar Apr 18 '24 02:04 25benjaminli

It should be possible to set the box input to None to disable it (i.e. prevent a box prompt from influencing the output).
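
For reference, here's a rough standalone sketch of that (the PromptEncoder constructor values and dummy prompts are just assumptions for illustration, not code from the repo): passing boxes=None simply means no box embedding gets concatenated.

import torch
from segment_anything.modeling import PromptEncoder

# Standalone prompt encoder using the default SAM sizes (assumed, for illustration only)
prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

# A batch of 8 single-point prompts (dummy coordinates in the 1024x1024 input space)
coords = torch.rand(8, 1, 2) * 1024
labels = torch.ones(8, 1, dtype=torch.int)  # 1 = foreground point

# boxes=None disables the box prompt entirely
sparse_emb, dense_emb = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse_emb.shape)  # torch.Size([8, 2, 256]): 1 point + 1 padding entry added when boxes is None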

There's also a (learned) 'not a point' embedding inside the model that you might want to experiment with, since it acts a bit like an 'empty' input. Unfortunately it requires editing the code to make use of it, but it's just a matter of replacing the sparse_embeddings output from the prompt encoder with something like:

# Overwrite the sparse embeddings with the 'not a point' embedding
num_batches = sparse_embeddings.shape[0]
not_point_weight = self.model.prompt_encoder.not_a_point_embed.weight
sparse_embeddings = not_point_weight.repeat(num_batches, 1, 1)

(this would completely wipe out all prompts and use only the 'not a point' embedding, which may not be desirable if you have points/labels that are part of the prompt)

heyoeyo avatar Apr 18 '24 13:04 heyoeyo

@heyoeyo thanks for the response. As you pointed out, it's not difficult to disable the box prompt for the whole batch, but do you see a way to "selectively allow" prompting within batches?

For example, this might be relevant in a situation with a batch of 8 images, where four of them have bounding box prompts and the other four do not.

I was thinking of modifying the code in the prompt encoder's forward to the following:

(-) box_embeddings = self._embed_boxes(boxes)
(-) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

(+) box_embeddings = self._embed_boxes(boxes)
(+) for i in range(bs):
(+)     if torch.all(boxes[i, 0] == 0):
(+)         # An all-zeros box marks an entry whose box prompt should be ignored, so skip its embedding
(+)         continue
(+)     else:
(+)         sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings[i:i+1]], dim=1)

25benjaminli avatar Apr 18 '24 14:04 25benjaminli

a way to "selectively allow" prompting within batches

As far as I can tell, the sparse embeddings have a shape of BxNx256, where N is related to the number of points and whether a box is given (it's something like: 1 + num points + 2 if a box is given). The interpretation being that I have 'B' N-point prompts that I'd like to process simultaneously.

For the example you gave of 8 images with 4 having boxes, my interpretation would be that this is like saying: 'I have 8 N-point prompts, except 4 of them are (N-2)'. So it's not clear what shape the BxNx256 should be in this case. If you wanted to completely avoid issues with interpreting 'empty' boxes, I think the 'purest' approach is to handle it as two separate batches, one 4xNx256 and one 4x(N-2)x256.
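
As a rough sketch of what I mean (assuming a standalone PromptEncoder with the default SAM settings and a made-up has_box mask, just for illustration), the two-batch version would look something like:

import torch
from segment_anything.modeling import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

# Hypothetical batch of 8 single-point prompts, where only the first 4 have real boxes
coords = torch.rand(8, 1, 2) * 1024         # (B, N, 2) dummy point coordinates
labels = torch.ones(8, 1, dtype=torch.int)  # (B, N) dummy foreground labels
boxes = torch.rand(8, 4) * 1024             # (B, 4) dummy xyxy boxes (only the first 4 are 'real')
has_box = torch.tensor([True, True, True, True, False, False, False, False])

# Two smaller batches, so every entry within a batch has the same number of sparse tokens
sparse_with_box, _ = prompt_encoder(
    points=(coords[has_box], labels[has_box]), boxes=boxes[has_box], masks=None
)  # (4, 3, 256): 1 point + 2 box-corner embeddings
sparse_no_box, _ = prompt_encoder(
    points=(coords[~has_box], labels[~has_box]), boxes=None, masks=None
)  # (4, 2, 256): 1 point + 1 padding entry (added automatically when there's no box)

(each of those batches would then need to go through the mask decoder separately, paired with its matching image embeddings)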

Using smaller batches would slow down training though, since it would force more back-and-forth between the CPU and GPU. If that's a problem, then the only option I can think of is to pad the smaller entries in each batch to match the largest entry. I think your 'box around the whole image' approach is already doing this, but another (maybe better?) option would be to pad with 2 of the 'not a point' embeddings whenever there is no box given. I'm assuming that's why those embeddings are part of the model to begin with, but I don't know for sure.
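
If it helps, here's a rough sketch of that padding idea (again assuming a standalone PromptEncoder and a made-up has_box mask; this isn't something the repo does by default): encode everything as if it had a box, then overwrite the 2 box-corner slots of the box-less entries with the 'not a point' embedding.

import torch
from segment_anything.modeling import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

# Hypothetical batch of 8 single-point prompts, where only the first 4 have real boxes
coords = torch.rand(8, 1, 2) * 1024
labels = torch.ones(8, 1, dtype=torch.int)
boxes = torch.rand(8, 4) * 1024  # dummy boxes; the last 4 rows are just placeholders
has_box = torch.tensor([True, True, True, True, False, False, False, False])

# Encode the whole batch as if every entry had a box, so every row has the same N
sparse_embeddings, dense_embeddings = prompt_encoder(
    points=(coords, labels), boxes=boxes, masks=None
)  # (8, 3, 256): 1 point + 2 box-corner embeddings per entry

# Replace the 2 box-corner slots of the box-less entries with the 'not a point' embedding,
# instead of relying on a box around the whole image
not_a_point = prompt_encoder.not_a_point_embed.weight  # shape: (1, 256)
sparse_embeddings = sparse_embeddings.clone()
sparse_embeddings[~has_box, -2:, :] = not_a_point.expand(2, -1)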

heyoeyo avatar Apr 19 '24 15:04 heyoeyo