segment-anything
"Empty" box input
Is there a way to add empty box input in batched box input to the prompt encoder? Currently, I have to make a bounding box around the whole image, and I'm wondering whether that might lead to overfitting if I'm training the model.
It should be possible to set the box input to None
to disable it (i.e. prevent a box prompt from influencing the output).
There's also a (learned) 'not a point' embedding inside the model that you might want to experiment with, since it acts a bit like an 'empty' input. Unfortunately, using it requires editing the code, but it's just a matter of replacing the sparse_embeddings
output from the prompt encoder with something like:
# Overwrite sparse embeddings with the 'not a point' embedding
num_batches = sparse_embeddings.shape[0]
not_point_weight = self.model.prompt_encoder.not_a_point_embed.weight
sparse_embeddings = not_point_weight.repeat(num_batches, 1, 1)
(this would completely wipe out all prompts and use only the 'not a point' embedding, which may not be desirable if you have points/labels that are part of the prompt)
@heyoeyo thanks for the response. As you pointed out, it's not difficult to disable box prompts for the whole batch, but do you see a way to "selectively allow" prompting within batches?
For example, this might be relevant in a situation with a batch of 8 images and four of them have bounding box prompts and the other four do not.
I was thinking of modifying the code in the prompt encoder's forward
to the following:
box_embeddings = self._embed_boxes(boxes)
(-) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
(+) for i in range(bs):
(+)     # An all-zeros box marks an entry whose box prompt should be ignored
(+)     if torch.all(boxes[i, 0] == 0):
(+)         continue
(+)     sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings[i:i+1]], dim=1)
a way to "selectively allow" prompting within batches
As far as I can tell, the sparse embeddings have a shape of BxNx256, where N is related to the number of points and whether a box is given (it's something like: 1 + num points + 2 if a box is given). The interpretation being that I have 'B' N-point prompts that I'd like to process simultaneously.
For the example you gave of 8 images with 4 having boxes, my interpretation would be that this is like saying: 'I have 8 N-point prompts, except 4 of them are (N-2)'. So it's not clear what shape the BxNx256 should be in this case. If you wanted to completely avoid issues with interpreting 'empty' boxes, I think the 'purest' approach is to handle it as two separate batches, one 4xNx256 and one 4x(N-2)x256.
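For illustration, the two-separate-batches idea might look something like the sketch below. The `run_decoder` function is a made-up stand-in for whatever consumes the sparse embeddings (not the actual SAM mask decoder), and the shapes are hypothetical:

```python
import torch

def run_decoder(sparse_embeddings: torch.Tensor) -> torch.Tensor:
    # Stand-in for the downstream decoder: produces one output per prompt
    return sparse_embeddings.mean(dim=(1, 2))

embed_dim = 256
# 4 prompts that include a box (N=3 tokens: 1 point + 2 box corners)
# and 4 prompts without a box (N-2=1 token each)
batch_with_box = torch.randn(4, 3, embed_dim)
batch_without_box = torch.randn(4, 1, embed_dim)

# Process the two groups as separate batches, then recombine the outputs
outputs = torch.cat(
    [run_decoder(batch_with_box), run_decoder(batch_without_box)], dim=0
)
print(outputs.shape)  # torch.Size([8])
```

This avoids any ambiguity about what an 'empty' box means, at the cost of two forward passes instead of one.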
Using smaller batches would slow down training though, since it would force more back-and-forth between the CPU/GPU. If that's a problem, then the only option I can think of is to pad the smaller entries in each batch to match the largest entry. I think your 'box around the whole image' is already doing this, but another (maybe better?) option would be to pad using 2 of the 'not a point' embeddings whenever there is no box given. I'm assuming that's why those embeddings are part of the model to begin with, but I don't know for sure.
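As a rough sketch of the padding idea, using random tensors as stand-ins for the prompt encoder's outputs and for `not_a_point_embed.weight` (all shapes here are assumptions for illustration):

```python
import torch

embed_dim = 256

# Stand-ins for the prompt encoder's sparse embeddings:
# 4 prompts with a box (1 point + 2 box-corner tokens) and 4 without
with_box = torch.randn(4, 3, embed_dim)
without_box = torch.randn(4, 1, embed_dim)

# Stand-in for prompt_encoder.not_a_point_embed.weight (shape 1x256)
not_a_point = torch.randn(1, embed_dim)

# Pad the box-less prompts with 2 copies of the 'not a point' embedding,
# in place of the 2 box-corner tokens they're missing
pad = not_a_point.reshape(1, 1, embed_dim).expand(without_box.shape[0], 2, embed_dim)
padded = torch.cat([without_box, pad], dim=1)

# Now every prompt has the same token count and can be batched together
sparse_embeddings = torch.cat([with_box, padded], dim=0)
print(sparse_embeddings.shape)  # torch.Size([8, 3, 256])
```

Whether the model behaves sensibly when fed 'not a point' embeddings in the box-token positions is an open question, since that's presumably not how it was trained.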