
What image size does the encoder need to produce an embedding?

Open · ramdhan1989 opened this issue 1 year ago · 2 comments

Hi, I am trying to use the encoder as a backbone for another downstream task, with the network architecture below.

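import torch.nn as nn
from segment_anything import sam_model_registry


class Net(nn.Module):
    def __init__(self, model_type, sam_checkpoint):
        super(Net, self).__init__()
        self.model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        # Freeze the SAM image encoder; only the head below is trained.
        for param in self.model.image_encoder.parameters():
            param.requires_grad = False
        self.model = self.model.image_encoder  # keep only the ViT image encoder
        # Flatten() keeps every spatial position, so in_features must equal C*H*W of the encoder output.
        self.avgpool = nn.Sequential(nn.Flatten(), nn.Linear(1024, 350), nn.ReLU(inplace=True), nn.Dropout(p=0.1))
        self.fc = nn.Linear(350, 1)

    def forward(self, image):
        x = self.model(image)
        x = self.avgpool(x)
        x = self.fc(x)
        return x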
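Separately from the error below, the head has a shape problem of its own. For a 1024x1024 input the encoder returns a (B, 256, 64, 64) feature map (see hans0809's printout further down), so nn.Flatten() yields 256 * 64 * 64 = 1,048,576 features per image, not the 1024 that nn.Linear(1024, 350) expects. A minimal sketch of a head that really average-pools first, assuming that (B, 256, H, W) layout; `PooledHead` is a hypothetical name, and a random tensor stands in for the encoder so it runs without a checkpoint:

```python
import torch
import torch.nn as nn


class PooledHead(nn.Module):
    """Hypothetical regression head for SAM image-encoder features.

    Assumes features of shape (B, 256, H, W), e.g. (B, 256, 64, 64) for 1024x1024 inputs.
    """

    def __init__(self, in_channels: int = 256, hidden: int = 350):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (B, C, H, W) -> (B, C, 1, 1), works for any H, W
        self.mlp = nn.Sequential(
            nn.Flatten(),                    # (B, C, 1, 1) -> (B, C)
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(hidden, 1),            # single output, like the original fc
        )

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.pool(feats))


if __name__ == "__main__":
    head = PooledHead()
    feats = torch.randn(4, 256, 64, 64)  # stand-in for image_encoder(x)
    print(head(feats).shape)  # torch.Size([4, 1])
```

Because the pooling is adaptive, the same head also accepts smaller feature maps if the encoder is later run at a different resolution.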

However, I got this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_295/2872027763.py in <module>
      8     print('------------------------------- Epoch: '+str(epoch))
      9 
---> 10     train_fn(train_loader, model, opt, loss_fn)
     11     new_loss = check_acc(val_loader, model)
     12     if new_loss < loss:

/tmp/ipykernel_295/825476113.py in train_fn(loader, model, opt, loss_fn)
      4         y = y.to(torch.float).unsqueeze(1).to('cuda')
      5 
----> 6         preds = model(x).to(torch.float)
      7 
      8         loss = loss_fn(preds, y)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_295/2068463466.py in forward(self, image)
     11 
     12     def forward(self, image):
---> 13         x = self.model(image)
     14         x = self.avgpool(x)
     15         x = self.fc(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/segment_anything/modeling/image_encoder.py in forward(self, x)
    107         x = self.patch_embed(x)
    108         if self.pos_embed is not None:
--> 109             x = x + self.pos_embed
    110 
    111         for blk in self.blocks:

RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 2

I tried other image sizes, such as (3, 512, 512), but the error is the same. Any ideas?

Thanks

ramdhan1989, Apr 10 '23 09:04
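The two numbers in the traceback come from the patch grid. Assuming SAM's defaults (img_size=1024 and a 16x16 patch embedding for all three released ViTs), pos_embed is sized for a 1024 // 16 = 64 by 64 grid, while a 512x512 input yields only 32 by 32 patches, hence tensor a (32) versus tensor b (64). A quick check of that arithmetic; `patch_grid` is a hypothetical helper, not part of segment_anything:

```python
# Assumed SAM defaults: 16x16 patch embedding (Conv2d with kernel_size=16, stride=16,
# no padding) and a pos_embed built for 1024x1024 inputs.
PATCH_SIZE = 16
PRETRAINED_IMG_SIZE = 1024


def patch_grid(img_size: int, patch_size: int = PATCH_SIZE) -> int:
    """Patches per side after a stride-`patch_size` conv without padding (floor division)."""
    return img_size // patch_size


PRETRAINED_GRID = patch_grid(PRETRAINED_IMG_SIZE)  # 64

for size in (512, 1024):
    grid = patch_grid(size)
    verdict = "matches" if grid == PRETRAINED_GRID else "does not match"
    print(f"{size}x{size} -> {grid}x{grid} patches, {verdict} the {PRETRAINED_GRID}x{PRETRAINED_GRID} pos_embed")
```

So any input whose side is not 1024 fails at this addition unless pos_embed is changed as well.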

Hi,

The error you're encountering is due to a mismatch in tensor sizes when adding the positional embeddings to the input tensor in the forward method of the image_encoder:

x = x + self.pos_embed

To resolve this error, you'll need to adjust the positional embeddings in the image_encoder to match the size of the input images. Here's a suggestion on how to do that:

In the Net class's __init__ method, after loading the SAM model, modify the positional embeddings of the image_encoder according to your input image size:

input_image_size = 512  # Update this value based on your input image size
num_patches = (input_image_size // 16) ** 2
pos_embed_shape = self.model.image_encoder.pos_embed.shape
new_pos_embed = torch.zeros(pos_embed_shape[0], num_patches, pos_embed_shape[2], device=self.model.image_encoder.pos_embed.device)
self.model.image_encoder.pos_embed = nn.Parameter(new_pos_embed)

This code creates a new positional embedding tensor of the appropriate size and assigns it to the image_encoder. Make sure to update the input_image_size variable based on the actual size of your input images.

After making this change, the model should be able to process input images of the desired size without encountering the mentioned error.

Aryan-Mishra24, Apr 10 '23 09:04
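As far as I can tell from segment_anything/modeling/image_encoder.py, this suggestion does not match how SAM stores pos_embed. There it is a 4-D tensor of shape (1, img_size // 16, img_size // 16, embed_dim), e.g. (1, 64, 64, 1280) for vit_h, and it is added to the patch embedding in (B, H, W, C) layout. So pos_embed_shape[2] is the grid width (64), not the channel count, and a zero tensor of shape (1, num_patches, 64) cannot broadcast against a (B, 32, 32, C) tensor. Zero-filling would also discard the learned positional embeddings. The snippet below reproduces the mismatch with random stand-in tensors (the vit_h sizes are an assumption), no checkpoint needed:

```python
import torch

# Stand-ins with vit_h-like sizes (assumed): embed_dim=1280, 16x16 patches.
embed_dim = 1280
pos_embed = torch.zeros(1, 64, 64, embed_dim)  # (1, H, W, C) layout, sized for 1024x1024 inputs
x = torch.randn(2, 32, 32, embed_dim)          # patch embedding of a (2, 3, 512, 512) batch, (B, H, W, C)

# The suggested replacement treats pos_embed as if it were (1, num_patches, C):
num_patches = (512 // 16) ** 2
pos_embed_shape = pos_embed.shape
new_pos_embed = torch.zeros(pos_embed_shape[0], num_patches, pos_embed_shape[2])
print(new_pos_embed.shape)  # torch.Size([1, 1024, 64]); index 2 is the grid width, not embed_dim

try:
    x + new_pos_embed
except RuntimeError as err:
    print(err)  # broadcasting fails on the last (channel) dimension: 1280 vs 64

# A tensor that does broadcast keeps the 4-D (1, H, W, C) layout:
ok = torch.zeros(1, 32, 32, embed_dim)
print((x + ok).shape)  # torch.Size([2, 32, 32, 1280])
```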

A 1024x1024 input works:

import torch
import torch.nn as nn
from segment_anything import sam_model_registry


class Net(nn.Module):
    def __init__(self, model_type, sam_checkpoint):
        super(Net, self).__init__()
        self.model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        # Freeze the encoder and keep only it.
        for param in self.model.image_encoder.parameters():
            param.requires_grad = False
        self.model = self.model.image_encoder

    def forward(self, image):
        x = self.model(image)
        print(x.shape)  # torch.Size([4, 256, 64, 64])
        return x


if __name__ == "__main__":
    ckpt_path = r"sam_vit_b_01ec64.pth"
    model = Net('vit_b', ckpt_path).cuda()
    print(model)
    x = torch.randn(4, 3, 1024, 1024).cuda()
    res = model(x)

hans0809, Apr 10 '23 09:04
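Since the released encoders expect 1024x1024 inputs, another option is to leave the encoder untouched and bring each image to that size first. The package's own predictor does roughly this: it resizes the longest side to 1024 (ResizeLongestSide in segment_anything.utils.transforms), then Sam.preprocess normalizes and zero-pads to a square. The function below is a self-contained approximation of that pipeline for a batch of float tensors, not the library code itself; `to_sam_input` is a hypothetical name, and the mean/std values are the pixel_mean/pixel_std defaults I believe SAM uses for 0-255 RGB input:

```python
import torch
import torch.nn.functional as F

# Assumed SAM defaults: 1024-pixel encoder input, ImageNet mean/std on a 0-255 scale.
TARGET = 1024
PIXEL_MEAN = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
PIXEL_STD = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)


def to_sam_input(images: torch.Tensor, target: int = TARGET) -> torch.Tensor:
    """Resize (B, 3, H, W) RGB images in [0, 255] so the longest side is `target`,
    normalize, then zero-pad bottom/right to (B, 3, target, target)."""
    h, w = images.shape[-2:]
    scale = target / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    x = F.interpolate(images.float(), size=(new_h, new_w), mode="bilinear", align_corners=False)
    x = (x - PIXEL_MEAN) / PIXEL_STD
    # Pad after normalizing, so padded pixels are exactly 0 in normalized space.
    return F.pad(x, (0, target - new_w, 0, target - new_h))


if __name__ == "__main__":
    batch = torch.rand(2, 3, 512, 384) * 255
    print(to_sam_input(batch).shape)  # torch.Size([2, 3, 1024, 1024])
```

Feeding the result to image_encoder should then give the (B, 256, 64, 64) features printed above, whatever the original image size.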

Hi @Aryan-Mishra24, could you please provide a working example based on your comment, for batch size > 1 and an image size other than the 1024x1024 that @hans0809 used?

It throws an error when I try to change pos_embed for 512x512 images, with an imgs tensor of shape (2, 3, 512, 512).

I just used your code:

input_image_size = 512  # Update this value based on your input image size
num_patches = (input_image_size // 16) ** 2
pos_embed_shape = model.pos_embed.shape
new_pos_embed = torch.zeros(pos_embed_shape[0], num_patches, pos_embed_shape[2], device=model.pos_embed.device)
model.pos_embed = nn.Parameter(new_pos_embed)
features = model.image_encoder(imgs)

And it gives me:

RuntimeError: The size of tensor a (1280) must match the size of tensor b (64) at non-singleton dimension 3

Thanks in advance :)

marjanstoimchev, May 15 '23 14:05
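A sketch of the working example requested above, under stated assumptions: instead of zero-filling, bilinearly interpolate the pretrained pos_embed from its 64x64 grid to the new grid (32x32 for 512x512 inputs), keeping the (1, H, W, C) layout. As far as I can tell, the relative-position tables in the global-attention blocks are already resized on the fly by get_rel_pos in image_encoder.py, so pos_embed should be the only tensor that needs changing. I have not measured how much accuracy this costs, and fine-tuning may be needed. `resize_pos_embed` is a hypothetical helper; the demo uses a random stand-in tensor so it runs without a checkpoint:

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Bilinearly resize a (1, H, W, C) absolute positional embedding to (1, new_grid, new_grid, C)."""
    pe = pos_embed.permute(0, 3, 1, 2)  # (1, C, H, W), the layout F.interpolate expects
    pe = F.interpolate(pe, size=(new_grid, new_grid), mode="bilinear", align_corners=False)
    return pe.permute(0, 2, 3, 1).contiguous()  # back to (1, new_grid, new_grid, C)


if __name__ == "__main__":
    # Assumed usage with a real model (not run here):
    #   sam = sam_model_registry["vit_h"](checkpoint=ckpt_path)
    #   enc = sam.image_encoder
    #   enc.pos_embed = torch.nn.Parameter(resize_pos_embed(enc.pos_embed.detach(), 512 // 16))
    #   features = enc(imgs)  # imgs: (2, 3, 512, 512); expected output (2, 256, 32, 32)
    stand_in = torch.randn(1, 64, 64, 1280)  # vit_h-sized stand-in for the pretrained pos_embed
    print(resize_pos_embed(stand_in, 32).shape)  # torch.Size([1, 32, 32, 1280])
```

Interpolating keeps the learned spatial structure of the embedding, which zero-filling throws away.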

https://github.com/ByungKwanLee/Full-Segment-Anything addresses critical issues of SAM: it supports batched input for the full-grid prompt (automatic mask generation), with post-processing that removes duplicated or small regions and holes, under flexible input image sizes.

ByungKwanLee, Oct 13 '23 21:10