segment-anything

Inference on a batch of images

Status: open · bach05 opened this issue 1 year ago · 14 comments

Hi,

I am trying to embed SAM in a custom framework to segment a batch of images, given one point per image. This snippet reproduces the error:

    import numpy as np
    import torch
    import torch.nn.functional as F
    from segment_anything import SamPredictor, sam_model_registry

    # Placeholders (assumed): set these to your own model type, checkpoint and device
    model_type = "vit_h"
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dummy inputs: 4 RGB images and a 2-channel (background/foreground) score map
    image = torch.rand(4, 3, 721, 721) * 255
    fw_feat = torch.rand(4, 2, 721, 721)

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    predictor = SamPredictor(sam)

    original_size = image.shape[-2:]
    size = sam.image_encoder.img_size  # 1024

    # set_torch_image expects BxCxHxW with the long side equal to img_size, on the model's device
    image = F.interpolate(image, size=(size, size), mode='bilinear', align_corners=False)
    image = image.to(torch.uint8).to(device)
    print("Image: ", image.shape)
    predictor.set_torch_image(image, original_size)
    print("Internal feat: ", predictor.features.shape)

    # EXTRACT PROMPT: use the foreground-score maximum of each image as its point
    foreground_pred = F.interpolate(fw_feat[:,1,:,:].unsqueeze(1), size=(size, size), mode='bilinear', align_corners=False)
    mask = torch.argmax(fw_feat, dim=1)
    #print("foreground_pred: ", torch.unique(mask))
    #foreground_pred[mask==0] = 0

    print("foreground_pred ", foreground_pred.shape)
    foreground_pred = foreground_pred.squeeze(1)
    max_val, max_pos_2 = torch.max(foreground_pred, dim=2)  # best column per row -> (B, H)
    print("Max mac pos 2: ", max_pos_2.shape)
    max_val, max_pos_1 = torch.max(max_val, dim=1)          # best row per image -> (B,)
    print("Max mac pos 1: ", max_pos_1.shape)

    max_positions = []
    for i, pos_1 in enumerate(max_pos_1):
        pos_2 = max_pos_2[i, pos_1]
        # NOTE: this stores (row, col); SAM's point_coords are (x, y), i.e. (col, row)
        max_positions.append([pos_1.item(), pos_2.item()])
    max_positions = np.array(max_positions)

    print(max_positions)

    input_point = torch.tensor(max_positions, device=device).unsqueeze(1)  # (B, 1, 2)
    point_label = torch.ones(input_point.shape[0], 1, device=device)      # (B, 1), 1 = foreground

    print("Input label: ", input_point.shape)
    print("Point label: ", point_label.shape)

    masks, scores, logits = predictor.predict_torch(
        point_coords=input_point,
        point_labels=point_label,
        boxes=None,
        multimask_output=False,
    )

This is the output I get:

Image:  torch.Size([4, 3, 1024, 1024])
Internal feat:  torch.Size([4, 256, 64, 64])
fw feat:  torch.Size([4, 721, 721])
foreground_pred  torch.Size([4, 1, 1024, 1024])
Max mac pos 2:  torch.Size([4, 1024])
Max mac pos 1:  torch.Size([4])
[[648 693]
 [534 659]
 [284 352]
 [773 580]]
Input label:  torch.Size([4, 1, 2])
Point label:  torch.Size([4, 1])
  0%|          | 0/9 [00:08<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 589, in <module>
    run()
  File "train.py", line 526, in run
    train_results = train(epoch, net, device, train_data, optimizer, args.batches_per_epoch, args, vis=args.vis)
  File "train.py", line 352, in train
    lossd = net.compute_loss(depth, yc_semantic, rgb, label, s_x, s_y, s_init_seed, device, type="train",ep=epoch)
  File "/home/bacchin/fsgnet2_venv/FSGNet_2.0/model/backbone_pos_angle_width.py", line 529, in compute_loss
    pos_pred, fs_out = self(features, rgb_x, support_x, support_y, s_seed)
  File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/bacchin/fsgnet2_venv/FSGNet_2.0/model/backbone_pos_angle_width.py", line 490, in forward
    self.refineWithSAM(fs_out, rgb_x)
  File "/home/bacchin/fsgnet2_venv/FSGNet_2.0/model/backbone_pos_angle_width.py", line 460, in refineWithSAM
    multimask_output=False,
  File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/segment_anything/predictor.py", line 234, in predict_torch
    multimask_output=multimask_output,
  File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/segment_anything/modeling/mask_decoder.py", line 98, in forward
    dense_prompt_embeddings=dense_prompt_embeddings,
  File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/segment_anything/modeling/mask_decoder.py", line 127, in predict_masks
    src = src + dense_prompt_embeddings
RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0

Can someone help me understand what is wrong?

Thank you!

bach05 commented on May 16 '23

In segment_anything/modeling/mask_decoder.py, lines 126-127:

    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = src + dense_prompt_embeddings

Here tokens.shape[0] is 4 (one prompt per image), so every image embedding is copied 4 times: [4, 256, 64, 64] -> [4*4=16, 256, 64, 64]. But dense_prompt_embeddings still has shape [4, 256, 64, 64], so the addition fails.

A possible solution:

    pos_src = torch.repeat_interleave(image_pe, src.shape[0], dim=0)
    sparse_prompt_embeddings = torch.repeat_interleave(sparse_prompt_embeddings, tokens.shape[0], dim=0)
    dense_prompt_embeddings = torch.repeat_interleave(dense_prompt_embeddings, tokens.shape[0], dim=0)
    tokens = torch.repeat_interleave(tokens, tokens.shape[0], dim=0)

mypydl commented on May 16 '23
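A minimal reproduction of the mismatch described above, using zero tensors with the shapes from the traceback (4 images, one point prompt each, so tokens.shape[0] is 4). Only torch is needed; the names mirror predict_masks, but the tensors are stand-ins rather than real embeddings.

```python
import torch

# Stand-ins with the shapes from the traceback: 4 images, one prompt per image.
image_embeddings = torch.zeros(4, 256, 64, 64)
dense_prompt_embeddings = torch.zeros(4, 256, 64, 64)
num_tokens = 4  # plays the role of tokens.shape[0] in predict_masks

# Upstream line: every image embedding is repeated once per prompt in the batch.
src = torch.repeat_interleave(image_embeddings, num_tokens, dim=0)
print(src.shape)  # torch.Size([16, 256, 64, 64])

error_message = None
try:
    src = src + dense_prompt_embeddings  # 16 vs 4 along dim 0
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```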

Thank you for your answer!

I also found this workaround:

    #src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src = image_embeddings
    src = src + dense_prompt_embeddings

Basically, I removed the duplication. Which solution do you think is better?

bach05 commented on May 16 '23
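To compare the two approaches on the batch dimension only, integer labels can stand in for the per-image tensors. Assuming one prompt per image, repeating every tensor by tokens.shape[0] keeps each image aligned with its own prompt but processes each pair four times (16 decoder inputs), whereas dropping the repeat gives one pair per image (4 inputs). This is a sketch of the indexing, not of the decoder itself.

```python
import torch

num_images = 4
# Integer labels stand in for per-image tensors (image embeddings, prompts, ...).
image_ids = torch.arange(num_images)
prompt_ids = torch.arange(num_images)  # prompt i belongs to image i

# Repeat everything by tokens.shape[0] (= 4): pairs stay aligned, each appears 4 times.
rep_images = torch.repeat_interleave(image_ids, num_images, dim=0)
rep_prompts = torch.repeat_interleave(prompt_ids, num_images, dim=0)
print(rep_images.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]

# No repetition: exactly one (image, prompt) pair per image.
print(list(zip(image_ids.tolist(), prompt_ids.tolist())))  # [(0, 0), (1, 1), (2, 2), (3, 3)]
```

In this sketch both variants pair image i with prompt i; the difference is the redundant work and memory of the repeated version.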

In reply to @mypydl's suggestion above:

Note the signature: torch.repeat_interleave(input, repeats, dim=None) → Tensor. The second parameter is the number of times each element of the input is repeated, so it should be:

    # Uniform size
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0]//image_embeddings.shape[0], dim=0)
    dense_prompt_embeddings = torch.repeat_interleave(dense_prompt_embeddings, tokens.shape[0]//dense_prompt_embeddings.shape[0], dim=0)

kong-johnny commented on May 21 '23
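The corrected repeat count can be wrapped in a small helper. This is a sketch: expand_to_prompts is a hypothetical name, and it assumes prompts are grouped by image (the first k prompts belong to image 0, the next k to image 1, and so on). With one prompt per image the factor is 1, which matches bach05's workaround; with a single image it reduces to the upstream repeat by tokens.shape[0].

```python
import torch


def expand_to_prompts(per_image: torch.Tensor, num_prompts: int) -> torch.Tensor:
    """Hypothetical helper: repeat per-image tensors so dim 0 matches the prompt count.

    Assumes prompts are grouped by image, k = num_prompts // num_images per image.
    """
    num_images = per_image.shape[0]
    if num_prompts % num_images != 0:
        raise ValueError(
            f"{num_prompts} prompts cannot be split evenly over {num_images} images"
        )
    return torch.repeat_interleave(per_image, num_prompts // num_images, dim=0)


# 4 images with 4 prompts: factor 1, no copies (same as removing the repeat).
print(expand_to_prompts(torch.zeros(4, 256, 64, 64), 4).shape)  # torch.Size([4, 256, 64, 64])
# 1 image with 5 prompts: factor 5, the case the upstream line handles.
print(expand_to_prompts(torch.zeros(1, 256, 64, 64), 5).shape)  # torch.Size([5, 256, 64, 64])
# 2 images with 3 prompts each: image 0 three times, then image 1 three times.
print(expand_to_prompts(torch.arange(2), 6).tolist())  # [0, 0, 0, 1, 1, 1]
```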

Quoting @kong-johnny ("...means the number of repeating your input"):

Why do src and the prompts need repeating?

usherbob commented on May 22 '23

Maybe my English wasn't clear. The decoder needs to bring the image embedding and the prompt embeddings to the same size and then add them, and the size they are matched to is the number of tokens (tokens.shape[0]). If you still have questions, please describe the specifics and I will reply as soon as I see them.

kong-johnny commented on May 22 '23
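Why the token count sets the batch size can be seen from how tokens is built at the start of MaskDecoder.predict_masks. The sketch below mirrors those lines with zero tensors standing in for the learned iou_token and mask_tokens weights; it assumes SAM's defaults (3 multimask outputs plus 1, so 4 mask tokens) and one point prompt per image without a box, in which case the prompt encoder appends a padding point, giving 2 sparse tokens per prompt.

```python
import torch

embed_dim, num_mask_tokens = 256, 4  # assumed defaults: 3 multimask outputs + 1
num_prompts = 4                      # one point prompt per image in this issue

# Stand-ins for the learned token weights.
iou_token = torch.zeros(1, embed_dim)
mask_tokens = torch.zeros(num_mask_tokens, embed_dim)
# Sparse prompt embeddings: the point plus a padding point (boxes=None).
sparse = torch.zeros(num_prompts, 2, embed_dim)

# Output tokens are expanded to the prompt batch, then concatenated with the prompts.
output_tokens = torch.cat([iou_token, mask_tokens], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse), dim=1)
print(tokens.shape)  # torch.Size([4, 7, 256]); tokens.shape[0] == number of prompts
```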

Same issue here.

I noticed that MedSAM makes the same modification as @bach05.

Compared to repeating all the embeddings, this seems more reliable?

Any suggestions?

chaoer commented on Jun 06 '23

In reply to @kong-johnny's suggestion above:

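What if I want several points per image, each producing its own mask? How can I input B C H W and get masks of shape B D H W (where D is the number of masks I want)?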

chenjjcccc commented on Jun 06 '23
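One possible layout for several points per image, sketched under assumptions: each of the D points is treated as its own single-point prompt, prompts are grouped by image, each image embedding is paired with its D prompts (for example by repeating it D times along dim 0), and the decoder returns one mask per prompt (multimask_output=False). The decoder output is replaced by a zero tensor here; only the reshapes are shown.

```python
import torch

B, D, H, W = 2, 3, 256, 256  # images, points per image, mask size (illustrative)

point_coords = torch.rand(B, D, 2) * 1024  # (x, y) in the 1024-pixel input frame
point_labels = torch.ones(B, D, dtype=torch.int)

# Each point becomes a single-point prompt; row-major order keeps image 0's points first.
prompt_coords = point_coords.reshape(B * D, 1, 2)
prompt_labels = point_labels.reshape(B * D, 1)

# Stand-in for the decoder output with multimask_output=False: one mask per prompt.
flat_masks = torch.zeros(B * D, 1, H, W)

# Fold back so that each image gets D mask channels: (B, D, H, W).
masks = flat_masks.reshape(B, D, H, W)
print(masks.shape)  # torch.Size([2, 3, 256, 256])
```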

In reply to @chaoer's comment above:

How can I automatically generate all masks for a batch without providing input_prompts? How do I call automatic_mask_generator?

chenjjcccc commented on Jun 06 '23

In reply to @chenjjcccc's question above:

I have the same need. Did you solve it?

zhaolei4383 commented on Aug 01 '23
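For the automatic case, upstream SamAutomaticMaskGenerator.generate takes a single HxWx3 uint8 image and returns a list of mask records (dicts), so a batch can at least be processed sequentially. This is a sketch, not true batched inference: to_hwc_uint8 and generate_for_batch are hypothetical helpers, and the input is assumed to be a (B, 3, H, W) tensor with values in [0, 255].

```python
import numpy as np
import torch


def to_hwc_uint8(images: torch.Tensor) -> list:
    """Split a (B, 3, H, W) tensor with values in [0, 255] into HxWx3 uint8 arrays."""
    arr = images.detach().cpu().clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
    return [np.ascontiguousarray(a) for a in arr]


def generate_for_batch(mask_generator, images: torch.Tensor) -> list:
    """Run a mask generator image by image; returns one list of mask records per image."""
    return [mask_generator.generate(img) for img in to_hwc_uint8(images)]
```

With a real model this would be called as generate_for_batch(SamAutomaticMaskGenerator(sam), images), where sam comes from sam_model_registry as in the original snippet.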

In reply to @chaoer's comment above:

I would suggest not using the predictor, which is a wrapper around the underlying modules. If you call the image_encoder, prompt_encoder, and mask_decoder modules directly, you can feed them batched tensors.

bach05 commented on Aug 01 '23
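A sketch of that suggestion that leaves mask_decoder.py unmodified: the image encoder runs on the whole batch, then the prompt encoder and mask decoder are called once per image with a single image embedding, so the decoder's repeat_interleave only has to expand one embedding to that image's prompts. This follows the same pattern as Sam.forward in segment_anything/modeling/sam.py. predict_batch is a hypothetical name; images are assumed to be already resized to 1024x1024, and point coordinates are (x, y) in that resized frame.

```python
import torch


@torch.no_grad()
def predict_batch(sam, images, point_coords, point_labels, original_size, multimask_output=False):
    """Hypothetical helper: batched image encoding, per-image prompt decoding.

    images:        (B, 3, 1024, 1024) tensor, already resized for the encoder
    point_coords:  (B, N, 2) tensor of (x, y) points in the resized frame
    point_labels:  (B, N) tensor, 1 = foreground, 0 = background
    original_size: (H, W) of the images before resizing
    Returns boolean masks (B, C, H, W) and IoU scores (B, C), C = 3 if multimask else 1.
    """
    input_size = tuple(images.shape[-2:])
    # The image encoder accepts a batch directly.
    features = sam.image_encoder(sam.preprocess(images))  # (B, 256, 64, 64)
    dense_pe = sam.prompt_encoder.get_dense_pe()          # (1, 256, 64, 64)

    all_masks, all_scores = [], []
    for i in range(images.shape[0]):
        # Image i's N points form one prompt, so the prompt batch size is 1.
        sparse, dense = sam.prompt_encoder(
            points=(point_coords[i : i + 1], point_labels[i : i + 1]),
            boxes=None,
            masks=None,
        )
        low_res, scores = sam.mask_decoder(
            image_embeddings=features[i : i + 1],  # one embedding, matching the prompt batch
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=multimask_output,
        )
        masks = sam.postprocess_masks(low_res, input_size, original_size)
        all_masks.append(masks > sam.mask_threshold)
        all_scores.append(scores)
    return torch.cat(all_masks, dim=0), torch.cat(all_scores, dim=0)
```

With a real model, sam would come from sam_model_registry[model_type](checkpoint=sam_checkpoint) as in the original snippet.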

In reply to @chaoer's comment above:

Of course you can; the key is making the sizes consistent with the number of tokens so that the forward pass can proceed. I preferred to copy the tokens because I'm not sure whether the subsequent steps are sensitive to the token count. If it turns out that the approach you mentioned is correct, I would prefer it, since it reduces CUDA memory usage.

kong-johnny commented on Aug 02 '23
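A rough sense of the memory difference, counting only the float32 src tensor in predict_masks for the shapes in this issue; other per-mask tensors such as pos_src scale the same way, so this is only part of the total. src_mib is a hypothetical helper for the arithmetic.

```python
def src_mib(num_masks: int, channels: int = 256, h: int = 64, w: int = 64, bytes_per_elem: int = 4) -> float:
    """Size in MiB of a (num_masks, channels, h, w) float32 tensor."""
    return num_masks * channels * h * w * bytes_per_elem / 2**20


print(src_mib(16))  # 64.0  (every embedding repeated 4 times)
print(src_mib(4))   # 16.0  (no repetition, one prompt per image)
```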

In reply to @chenjjcccc's question above:

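I think what you describe would be a great application of SAM, but the current demo does not seem to provide such a design: we cannot control the number of masks through the number of prompts alone. Currently, I control the number of masks by modifying the following parameters, as well as multimask_output in the forward pass: [image] It is worth mentioning that SAM's current automatic generation API seems to match your requirement of multiple prompt inputs and multiple mask outputs, except that there is no one-to-one correspondence and it does not support batch processing. We are currently working on a SamAutomaticMaskGenerator that supports batch processing; if you are interested in our work, you are welcome to join the development. The above is my personal experience from using it, not based on a detailed reading of the code; please correct me if anything is wrong.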

kong-johnny commented on Aug 02 '23
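On the multimask_output point: in upstream MaskDecoder.forward the decoder always predicts num_multimask_outputs + 1 masks per prompt (4 with the defaults) and then slices them, so each prompt yields either 1 or 3 masks. A minimal sketch of that selection with zero tensors; select_masks is a hypothetical name and the shapes assume the default configuration.

```python
import torch

num_prompts, num_mask_tokens = 4, 4  # assumed defaults: 1 + 3 masks predicted per prompt
masks = torch.zeros(num_prompts, num_mask_tokens, 256, 256)
iou_pred = torch.zeros(num_prompts, num_mask_tokens)


def select_masks(masks, iou_pred, multimask_output):
    # Keep the 3 multimask outputs, or only the first (single-mask) output.
    mask_slice = slice(1, None) if multimask_output else slice(0, 1)
    return masks[:, mask_slice, :, :], iou_pred[:, mask_slice]


print(select_masks(masks, iou_pred, True)[0].shape)   # torch.Size([4, 3, 256, 256])
print(select_masks(masks, iou_pred, False)[0].shape)  # torch.Size([4, 1, 256, 256])
```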

https://github.com/ByungKwanLee/Full-Segment-Anything addresses the critical issues of SAM: it supports batched input with the full-grid prompt (automatic mask generation), including post-processing (removing duplicated or small regions and holes), with flexible input image sizes.

ByungKwanLee commented on Oct 13 '23

In reply to @usherbob's question above ("Why do src and the prompts need repeating?"):

Because the image is not always the whole image: AMG (the automatic mask generator) crops sections of the image.

jez-moxmo commented on Dec 29 '23