Multiple Points, Labels, and Boxes while Batch Prompting

Jordan-Pierce opened this issue 1 year ago • 8 comments

Thanks again for the release, very useful!

I'm currently trying to do batch prompting by providing bounding boxes (creating masks for objects already annotated), and I've noticed that sometimes the boxes alone do not produce complete masks. My idea was to provide a bounding box together with multiple points sampled within that box, in the hope that together they would create a better mask for those edge cases.

The notebook provides a clear example of how to perform a prediction for a single point, a single bounding box, and multiple bounding boxes, but not for multiple points (with labels) combined with bounding boxes. When I try to do this I keep running into an error, and it's not clear whether I'm doing it incorrectly or whether it simply isn't possible. Below is an example of what I thought would work:

import cv2
import torch

# cocodict (COCO-style annotations), predictor (SamPredictor), and
# sample_points are defined elsewhere
image = cv2.imread(cocodict['images'][0]['path'])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

input_boxes = torch.tensor([b['bbox'] for b in cocodict['annotations']], device=predictor.device)

# Returns a list of 10 points, per bounding box, np array
input_points = sample_points(input_boxes, 10)
input_points = torch.tensor(input_points, device=predictor.device)

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
transformed_points = predictor.transform.apply_points_torch(input_points, image.shape[:2])
transformed_labels = torch.tensor([1 for _ in range(len(input_points))], device=predictor.device)

predictor.set_image(image)

masks, _, _ = predictor.predict_torch(
    point_coords=transformed_points,
    point_labels=transformed_labels,
    boxes=transformed_boxes,
    multimask_output=False,
)

I understand that I can get the results I'm looking for by breaking the process up, running one prompt at a time, and joining the masks, but I'd like to know if there is a working solution for this batched approach. Thanks.

Jordan-Pierce avatar Apr 08 '23 15:04 Jordan-Pierce

@Jordan-Pierce Is this the same issue ? https://github.com/facebookresearch/segment-anything/issues/115

zdhernandez avatar Apr 08 '23 16:04 zdhernandez

Hi @zdhernandez, not quite. Thanks for the response though.

JordanMakesMaps avatar Apr 08 '23 17:04 JordanMakesMaps

Did you check the shape of your transformed_points and transformed_labels? transformed_points should be a BxNx2 tensor and transformed_labels should be a BxN tensor. N is the number of sampled points here.

HannaMao avatar Apr 09 '23 04:04 HannaMao
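
For reference, a minimal sketch of what those shapes look like when combining boxes and points in one batched call (assuming a SamPredictor has already been created and set_image has been called on image; all values below are placeholders):

import torch

# B boxes with N sampled points per box
B, N = 3, 10

# Boxes: Bx4 in XYXY pixel coordinates
input_boxes = torch.tensor([[ 10,  10, 100, 100],
                            [120,  30, 200, 150],
                            [ 50, 200, 180, 300]],
                           dtype=torch.float, device=predictor.device)

# Points: BxNx2, one row of N (x, y) points per box
input_points = torch.rand(B, N, 2, device=predictor.device) * 300

# Labels: BxN, one label per point (1 = foreground, 0 = background)
input_labels = torch.ones(B, N, dtype=torch.int, device=predictor.device)

# Note: the transform class exposes apply_coords_torch for points
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
transformed_points = predictor.transform.apply_coords_torch(input_points, image.shape[:2])

masks, scores, _ = predictor.predict_torch(
    point_coords=transformed_points,  # BxNx2
    point_labels=input_labels,        # BxN
    boxes=transformed_boxes,          # Bx4
    multimask_output=False,
)
# masks: Bx1xHxW, one mask per box/point group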

@HannaMao @Jordan-Pierce I have the same issue. I found that the number of points and boxes must be the same, otherwise an error occurs (here, 8 boxes and 7 points):

Traceback (most recent call last):
  File "D:\python\lib\site-packages\gradio\routes.py", line 395, in run_predict
    output = await app.get_blocks().process_api(
  File "D:\python\lib\site-packages\gradio\blocks.py", line 1193, in process_api
    result = await self.call_function(
  File "D:\python\lib\site-packages\gradio\blocks.py", line 916, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "D:\python\lib\site-packages\anyio\to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "D:\python\lib\site-packages\anyio\_backends\_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "D:\python\lib\site-packages\anyio\_backends\_asyncio.py", line 867, in run
    result = context.run(func, *args)
  File "F:\code\segment_anything_webui\inference.py", line 183, in run_inference
    return predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold)
  File "F:\code\segment_anything_webui\inference.py", line 145, in predictor_inference
    masks, scores, logits = predictor.predict_torch(
  File "D:\python\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\python\lib\site-packages\segment_anything\predictor.py", line 222, in predict_torch
    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
  File "D:\python\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\python\lib\site-packages\segment_anything\modeling\prompt_encoder.py", line 162, in forward
    sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 7 but got size 8 for tensor number 1 in the list.

5663015 avatar Apr 14 '23 16:04 5663015
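
That matches the traceback: the prompt encoder concatenates the point and box embeddings per batch entry, so the batch dimension of the point tensor has to equal the number of boxes. A quick sanity check (using the variable names from the snippet in the original post) might look like:

# B boxes, N sampled points per box; both prompts must share the same B
B = transformed_boxes.shape[0]   # e.g. 8
N = transformed_points.shape[1]  # e.g. 10

assert transformed_points.shape == (B, N, 2)
assert transformed_labels.shape == (B, N)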

@5663015 thanks for double checking; @HannaMao can you confirm or deny whether what we're trying to do is possible given the current implementation of SAM?

Alternatively, we can always run it twice, once with points and once with boxes, and find a way to join the masks.

Jordan-Pierce avatar Apr 14 '23 16:04 Jordan-Pierce

@nikhilaravi @Jordan-Pierce I'm looking for exactly the same thing. Have you solved this issue? My object detector detects two objects, but my SAM model only produces a mask for one of them. How can we do this for multiple objects in a single image?

akashAD98 avatar Apr 18 '23 14:04 akashAD98

I guess at the moment you need to run the decoder on every single object separately. I am trying the same approach: the detector (YOLO) detects multiple objects, and the coordinates of those boxes are then passed to the decoder as:

point = index * 2
point_coords[0, point, 0] = rect.Left       // box top-left x
point_coords[0, point, 1] = rect.Top        // box top-left y
point_coords[0, point + 1, 0] = rect.Right  // box bottom-right x
point_coords[0, point + 1, 1] = rect.Bottom // box bottom-right y

// labels 2 and 3 encode a box's top-left and bottom-right corners
// in SAM's ONNX-style prompt format
point_labels[0, point]     = 2
point_labels[0, point + 1] = 3

When it's only one object everything works fine, but it returns a complete mess when the object detector finds more than one object.

Please tell me I'm wrong about this and simply making mistakes when running inference with SAM.

VladMVLX avatar Apr 30 '23 15:04 VladMVLX
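
The snippet above packs every box's corners into the point list of a single prompt, so SAM ends up predicting one mask that tries to satisfy all of the boxes at once. A sketch of the alternative already discussed in this thread, where each detection gets its own entry along the batch dimension (assuming XYXY boxes from the detector and a predictor with set_image already called):

import torch

# One XYXY box per detected object (placeholder values)
detections = torch.tensor([[ 34,  50, 210, 300],
                           [240,  80, 400, 260]],
                          dtype=torch.float, device=predictor.device)

# Each box is its own prompt along the batch dimension, so the decoder
# returns one independent mask per object
transformed_boxes = predictor.transform.apply_boxes_torch(detections, image.shape[:2])

masks, scores, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,  # Bx4 -> B separate masks
    multimask_output=False,
)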

Has apply_points_torch() been removed? I am unable to find it. Using the code above gives me an AttributeError.

AttributeError: 'ResizeLongestSide' object has no attribute 'apply_points_torch'

What should be used instead to perform batch prompting using points?

kjtheron avatar May 05 '23 18:05 kjtheron
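
As far as I can tell, ResizeLongestSide has never had an apply_points_torch method, which is why the code in the original post raises the AttributeError. Point coordinates are transformed with apply_coords_torch (or apply_coords for numpy arrays), e.g.:

# input_points is a BxNx2 tensor of pixel coordinates, as in the earlier snippets
transformed_points = predictor.transform.apply_coords_torch(input_points, image.shape[:2])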

Did you check the shape of your transformed_points and transformed_labels? transformed_points should be a BxNx2 tensor and transformed_labels should be a BxN tensor. N is the number of sampled points here.

From the SAM GitHub docstring:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts.
point_labels (torch.Tensor or None): A BxN array of labels.
boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format.

Can I ask what B and N mean here in BxNx2 tensor size? Also, points and boxes share B as a common parameter. So, do I need the same number of points and boxes?

Thank you.

stharajkiran avatar May 15 '23 06:05 stharajkiran

Did you check the shape of your transformed_points and transformed_labels? transformed_points should be a BxNx2 tensor and transformed_labels should be a BxN tensor. N is the number of sampled points here.

From the SAM GitHub docstring:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts.
point_labels (torch.Tensor or None): A BxN array of labels.
boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format.

Can I ask what B and N mean here in BxNx2 tensor size? Also, points and boxes share B as a common parameter. So, do I need the same number of points and boxes?

Thank you.

The first question: B is the number of boxes; N is the number of point samples for each box.

The second question: no. Each box needs its own set of points (I think you need a fixed number of points for each box, meaning you can't use 4 points for the first box and 5 for the second, but I'm not sure about that).

JasseurHadded1 avatar May 16 '23 10:05 JasseurHadded1

That is correct, each box needs a sequence of point coords.

HobbitArmy avatar Jul 08 '23 04:07 HobbitArmy
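
On the fixed-number-of-points question: reading the prompt encoder, points with label -1 are treated as padding ("not a point") and their coordinates are ignored, so one way to batch boxes with different numbers of points is to pad the shorter point lists. This is a sketch based on that reading, not something confirmed in this thread; pad_point_groups is a made-up helper:

import torch

def pad_point_groups(point_groups, device):
    # point_groups: list of (n_i x 2) tensors, one per box, in pixel coords.
    # Returns BxNx2 coords and BxN labels, where padded entries get label -1,
    # which the SAM prompt encoder treats as "not a point".
    n_max = max(p.shape[0] for p in point_groups)
    coords, labels = [], []
    for pts in point_groups:
        n = pts.shape[0]
        pad = torch.zeros(n_max - n, 2, device=device)
        coords.append(torch.cat([pts.to(device), pad], dim=0))
        labels.append(torch.cat([torch.ones(n, device=device),
                                 -torch.ones(n_max - n, device=device)], dim=0))
    return torch.stack(coords), torch.stack(labels)

# Example: 4 foreground points for the first box, 5 for the second
groups = [torch.rand(4, 2) * 300, torch.rand(5, 2) * 300]
point_coords, point_labels = pad_point_groups(groups, predictor.device)  # 2x5x2, 2x5

The padded coordinates still go through apply_coords_torch with everything else; since their labels are -1, the model ignores where they land.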

https://github.com/ByungKwanLee/Full-Segment-Anything addresses the critical issues of SAM: it supports batched input on the full-grid prompt (automatic mask generation) with post-processing (removing duplicated or small regions and holes) under flexible input image sizes.

ByungKwanLee avatar Oct 16 '23 07:10 ByungKwanLee

Thanks @ByungKwanLee

Jordan-Pierce avatar Oct 16 '23 13:10 Jordan-Pierce

Did you check the shape of your transformed_points and transformed_labels? transformed_points should be a BxNx2 tensor and transformed_labels should be a BxN tensor. N is the number of sampled points here.

From the SAM GitHub docstring:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts.
point_labels (torch.Tensor or None): A BxN array of labels.
boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format.

Can I ask what B and N mean here in the BxNx2 tensor size? Also, points and boxes share B as a common parameter. So, do I need the same number of points and boxes? Thank you.

The first question: B is the number of boxes; N is the number of point samples for each box.

The second question: no. Each box needs its own set of points (I think you need a fixed number of points for each box, meaning you can't use 4 points for the first box and 5 for the second, but I'm not sure about that).

[image attachment]

Here, I have met all the size requirements, but I still get an error saying "too many indices for tensor of dimension 3" for the points. I have no idea how to figure out the issue. If I use only the boxes, it runs fine though!

Here, B = 4, N = 1

stharajkiran avatar Oct 28 '23 19:10 stharajkiran

Here, I have met all the size requirements, but I still get an error saying "too many indices for tensor of dimension 3" for the points. I have no idea how to figure out the issue. If I use only the boxes, it runs fine though!

It's hard to say without seeing the error, but one possible issue is that the predict_torch(...) function is being given point_labels= input_label instead of transformed_labels. Otherwise, the input shapes seem fine (I can run this without errors using randomly generated data of the same shape at least).

heyoeyo avatar Oct 29 '23 15:10 heyoeyo
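
For reference, a shape-only smoke test along the lines described above (B = 4 boxes, N = 1 point per box, random data; the model type and checkpoint path are placeholders) might look like:

import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

# Placeholder model type and checkpoint path; substitute your own
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

# Random image and prompts: B = 4 boxes, N = 1 point per box
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
predictor.set_image(image)

boxes = torch.tensor([[ 10,  10, 100, 100],
                      [110,  10, 200, 100],
                      [ 10, 110, 100, 200],
                      [110, 110, 200, 200]],
                     dtype=torch.float, device=predictor.device)
points = torch.rand(4, 1, 2, device=predictor.device) * 200          # Bx1x2
labels = torch.ones(4, 1, dtype=torch.int, device=predictor.device)  # Bx1

masks, scores, _ = predictor.predict_torch(
    point_coords=predictor.transform.apply_coords_torch(points, image.shape[:2]),
    point_labels=labels,
    boxes=predictor.transform.apply_boxes_torch(boxes, image.shape[:2]),
    multimask_output=False,
)
print(masks.shape)  # expected: torch.Size([4, 1, 480, 640])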