segment-anything
Inference on a batch of images
Hi,
I am trying to embed SAM in a custom framework to perform the segmentation of batched images, given a point for each image. This is a snippet of code to reproduce the error:
import numpy as np
import torch
import torch.nn.functional as F
from segment_anything import sam_model_registry, SamPredictor

image = torch.rand(4, 3, 721, 721) * 255
fw_feat = torch.rand(4, 2, 721, 721)
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
original_size = image.shape[-2:]
size = sam.image_encoder.img_size
image = F.interpolate(image, size=(size, size), mode='bilinear', align_corners=False)
image = image.to(torch.uint8)
print("Image: ", image.shape)
predictor.set_torch_image(image, original_size)
print("Internal feat: ", predictor.features.shape)
#EXTRACT PROMPT
foreground_pred = F.interpolate(fw_feat[:,1,:,:].unsqueeze(1), size=(size, size), mode='bilinear', align_corners=False)
mask = torch.argmax(fw_feat, dim=1)
#print("foreground_pred: ", torch.unique(mask))
#foreground_pred[mask==0] = 0
print("foreground_pred ", foreground_pred.shape)
foreground_pred = foreground_pred.squeeze(1)
max_val, max_pos_2 = torch.max(foreground_pred, dim=2)
print("Max mac pos 2: ", max_pos_2.shape)
max_val, max_pos_1 = torch.max(max_val, dim=1)
print("Max mac pos 1: ", max_pos_1.shape)
max_positions = []
for i, pos_1 in enumerate(max_pos_1):
    pos_2 = max_pos_2[i, pos_1]
    max_positions.append([pos_1.item(), pos_2.item()])
max_positions = np.array(max_positions)
print(max_positions)
input_point = torch.tensor(max_positions, device=device).unsqueeze(1)
point_label = torch.ones(input_point.shape[0], 1)
print("Input label: ", input_point.shape)
print("Point label: ", point_label.shape)
masks, scores, logits = predictor.predict_torch(
    point_coords=input_point,
    point_labels=point_label,
    boxes=None,
    multimask_output=False,
)
The output I get is the following:
Image: torch.Size([4, 3, 1024, 1024])
Internal feat: torch.Size([4, 256, 64, 64])
fw feat: torch.Size([4, 721, 721])
foreground_pred torch.Size([4, 1, 1024, 1024])
Max mac pos 2: torch.Size([4, 1024])
Max mac pos 1: torch.Size([4])
[[648 693]
[534 659]
[284 352]
[773 580]]
Input label: torch.Size([4, 1, 2])
Point label: torch.Size([4, 1])
0%| | 0/9 [00:08<?, ?it/s]
Traceback (most recent call last):
File "train.py", line 589, in <module>
run()
File "train.py", line 526, in run
train_results = train(epoch, net, device, train_data, optimizer, args.batches_per_epoch, args, vis=args.vis)
File "train.py", line 352, in train
lossd = net.compute_loss(depth, yc_semantic, rgb, label, s_x, s_y, s_init_seed, device, type="train",ep=epoch)
File "/home/bacchin/fsgnet2_venv/FSGNet_2.0/model/backbone_pos_angle_width.py", line 529, in compute_loss
pos_pred, fs_out = self(features, rgb_x, support_x, support_y, s_seed)
File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/bacchin/fsgnet2_venv/FSGNet_2.0/model/backbone_pos_angle_width.py", line 490, in forward
self.refineWithSAM(fs_out, rgb_x)
File "/home/bacchin/fsgnet2_venv/FSGNet_2.0/model/backbone_pos_angle_width.py", line 460, in refineWithSAM
multimask_output=False,
File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
return func(*args, **kwargs)
File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/segment_anything/predictor.py", line 234, in predict_torch
multimask_output=multimask_output,
File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/segment_anything/modeling/mask_decoder.py", line 98, in forward
dense_prompt_embeddings=dense_prompt_embeddings,
File "/home/bacchin/fsgnet2_venv/lib/python3.6/site-packages/segment_anything/modeling/mask_decoder.py", line 127, in predict_masks
src = src + dense_prompt_embeddings
RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0
Can someone help me to understand what is wrong?
Thank you!
segment_anything/modeling/mask_decoder.py -- lines 126-127
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
All image_embeddings will be copied 4 times: [4, 256, 64, 64] -> [4*4=16, 256, 64, 64], but dense_prompt_embeddings has shape [4, 256, 64, 64].
A possible solution:
pos_src = torch.repeat_interleave(image_pe, src.shape[0], dim=0)
sparse_prompt_embeddings = torch.repeat_interleave(sparse_prompt_embeddings, tokens.shape[0], dim=0)
dense_prompt_embeddings = torch.repeat_interleave(dense_prompt_embeddings, tokens.shape[0], dim=0)
tokens = torch.repeat_interleave(tokens, tokens.shape[0], dim=0)
Thank you for your answer!
I also found this workaround:
#src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = image_embeddings
src = src + dense_prompt_embeddings
Basically I remove the duplication. Which solution do you think is better?
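For reference, here is a quick shape check (a sketch with made-up tensors, assuming exactly one prompt set per image) of why dropping the repeat lines up:

import torch

B = 4                                                   # image batch size
image_embeddings = torch.zeros(B, 256, 64, 64)          # output of the image encoder
dense_prompt_embeddings = torch.zeros(B, 256, 64, 64)   # one dense prompt embedding per image

# With one prompt set per image the batch dimensions already agree, so the
# addition works without any repeat_interleave.
src = image_embeddings
print((src + dense_prompt_embeddings).shape)            # torch.Size([4, 256, 64, 64])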
Attention: the signature of repeat_interleave is torch.repeat_interleave(input, repeats, dim=None) -> Tensor. The second parameter is the number of times to repeat the input, so it should be:
# Uniform Size
src = torch.repeat_interleave(image_embeddings, tokens.shape[0]//image_embeddings.shape[0], dim=0)
dense_prompt_embeddings = torch.repeat_interleave(dense_prompt_embeddings, tokens.shape[0]//dense_prompt_embeddings.shape[0], dim=0)
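As a tiny illustration of that point (made-up values, not from the thread), the second argument is a repeat count, not a target batch size:

import torch

x = torch.tensor([[1], [2]])
# Each row is repeated twice; passing tokens.shape[0] here therefore multiplies
# the batch dimension instead of setting it.
print(torch.repeat_interleave(x, 2, dim=0))
# tensor([[1],
#         [1],
#         [2],
#         [2]])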
Why do src and prompt need repeating?
Maybe my English was not clear: the decoder needs to bring the image encoding and the prompt encoding to a common size before adding them, and that common size is the token size. If you still have questions, please describe them in more detail and I will reply as soon as I see them.
Same issue.
I noticed that in MedSAM, they make the same modification as @bach05.
Compared to repeating all the embeddings, this seems to be more reliable?
Any suggestions?
What if I want multiple points for each image, generating multiple masks? How do I feed B C H W and get masks of shape B D H W (where D is the number of masks I want)?
How can I automatically generate all masks for a batch without giving input_prompts? How do I call automatic_mask_generator?
I have the same need; have you solved it?
I can suggest not using the predictor, which is a wrapper around the underlying methods. If you use the image_encoder, prompt_encoder and mask_decoder methods directly, you can feed batched tensors.
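For what it's worth, here is a minimal, untested sketch of what that could look like. The names sam, images, coords, and labels are placeholders (set up as described in the comments), and the prompts are decoded one image at a time so that the image and prompt embedding batch sizes always match:

# Minimal sketch (untested) of bypassing SamPredictor and calling the SAM
# sub-modules directly. Assumes `sam` is a loaded Sam model, `images` is a
# [B, 3, 1024, 1024] float tensor already run through sam.preprocess, and
# `coords`/`labels` ([B, N, 2] and [B, N]) are point prompts in the
# 1024x1024 coordinate frame.
import torch

with torch.no_grad():
    image_embeddings = sam.image_encoder(images)            # [B, 256, 64, 64]

    low_res = []
    for b in range(images.shape[0]):
        # Encode the prompts of image b only.
        sparse_emb, dense_emb = sam.prompt_encoder(
            points=(coords[b : b + 1], labels[b : b + 1]),
            boxes=None,
            masks=None,
        )
        # Decode with matching batch sizes (one image embedding, one prompt
        # set), which sidesteps the repeat_interleave mismatch discussed above.
        masks_b, iou_b = sam.mask_decoder(
            image_embeddings=image_embeddings[b : b + 1],
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=False,
        )
        low_res.append(masks_b)

    low_res_masks = torch.cat(low_res, dim=0)               # [B, 1, 256, 256]
    # Upscale to the original resolution with sam.postprocess_masks if needed.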
Of course you can; the key is standardizing the size of each tensor so that the procedure keeps going. I preferred to copy the tokens because I am not sure whether the follow-up procedure is sensitive to the token size. If it is confirmed that the approach you mention works, I would prefer it, to reduce CUDA memory usage.
I think what you describe (multiple points per image producing multiple masks) would be a very good application of SAM, but the current demo does not seem to provide such a design; we cannot control the number of masks purely through the number of prompts. The way I currently control the number of masks includes modifying the following parameters as well as multimask_output in the forward pass.
It is worth mentioning that SAM's automatic mask generation API does seem to match the multi-prompt-in, multi-mask-out behaviour you are asking for, only without a one-to-one correspondence between prompts and masks, and without batch support. We are currently working on a batch-capable
SamAutomaticMaskGenerator
; if you are interested in our work, you are welcome to join the development.
The above is my personal experience from using it, without referring to the code in detail; please correct me if anything is wrong.
https://github.com/ByungKwanLee/Full-Segment-Anything addresses the critical issues of SAM: it supports batched input for the full-grid prompt (automatic mask generation) with post-processing (removing duplicated or small regions and holes), under flexible input image sizes.
Why do src and prompt need repeating?
Because the image is not always the whole image, as the automatic mask generator (AMG) crops sections of the image.