Make `ImageEncoderViT` compilable with `torch.jit.script`

Open • fkodom opened this issue 1 year ago • 6 comments

It would be very convenient if `ImageEncoderViT` were compilable with `torch.jit.script`.

Issue

Load the SAM model and attempt to compile its image encoder with `torch.jit.script`:

import torch
from segment_anything import sam_model_registry

model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint="...")
scripted = torch.jit.script(sam.image_encoder)

which raises the following error:

RuntimeError: 

pad_hw is not defined in the false branch:
  File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 170
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
        ~~~~~~~~~~~~~~~~~~~~~~~~
            H, W = x.shape[1], x.shape[2]
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            x, pad_hw = window_partition(x, self.window_size)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
        x = self.attn(x)
and was used here:
  File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 177
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
                                                        ~~~~~~ <--- HERE
    
        x = shortcut + x

Solution

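A simple fix is to avoid writing `if self.window_size > 0:` in two separate places. Put the window partition, the attention call, and the reverse partition under a single `if`, so that `pad_hw` (and `H`, `W`) are defined on every code path that uses them.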
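The single-`if` restructuring can be sketched as follows. This is a minimal, self-contained sketch, not the actual patch: `WindowedBlock` is a hypothetical, simplified stand-in for SAM's `Block` (an `nn.Linear` replaces the attention layer), and `window_partition` / `window_unpartition` are re-implemented here in the spirit of the helpers in `image_encoder.py` rather than imported from `segment_anything`. The point is the control flow: `pad_hw`, `H`, and `W` are created and consumed inside the same branch.

```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Pad (B, H, W, C) so H and W are multiples of window_size, then cut into windows.
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, [0, 0, 0, pad_w, 0, pad_h])
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows: torch.Tensor, window_size: int,
                       pad_hw: Tuple[int, int], hw: Tuple[int, int]) -> torch.Tensor:
    # Reassemble windows into (B, Hp, Wp, C), then crop the padding away.
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // ((Hp // window_size) * (Wp // window_size))
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


class WindowedBlock(nn.Module):
    # Hypothetical, simplified stand-in for SAM's Block; nn.Linear replaces attention.
    def __init__(self, dim: int, window_size: int) -> None:
        super().__init__()
        self.window_size = window_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            # H, W and pad_hw are defined and used inside this one branch,
            # so TorchScript never sees them as undefined on another path.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
            x = self.attn(x)
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        else:
            x = self.attn(x)
        return shortcut + x


if __name__ == "__main__":
    block = WindowedBlock(dim=8, window_size=4).eval()
    scripted = torch.jit.script(block)  # compiles with the single-if layout
    x = torch.randn(2, 10, 10, 8)  # 10 is not a multiple of 4, so padding is exercised
    with torch.no_grad():
        print(torch.allclose(block(x), scripted(x), atol=1e-5))
```

The `else` branch repeats the `self.attn(x)` call; that small duplication is what lets TorchScript see every variable defined on every path that reads it.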

fkodom • Apr 14 '23 20:04

import torch
import torchvision
import numpy as np
from segment_anything import sam_model_registry

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

model_checkpoint = '/path/to/sam_vit_b_01ec64.pth'
sam = sam_model_registry["vit_b"](checkpoint=model_checkpoint).to(device='cuda')

images = np.zeros((10, 256, 256, 3), dtype='uint8')

batched_input = []
for i in range(images.shape[0]):
    batched_input.append({
        'image': torch.as_tensor(images[i], device=sam.device).permute(2, 0, 1).contiguous(),
        'original_size': images[i].shape[:2],
    })

HIMANSHUSINGHYANIA • Apr 18 '23 11:04

pad_hw is not defined in the false branch:
  File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 170
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
        ~~~~~~~~~~~~~~~~~~~~~~~~
            H, W = x.shape[1], x.shape[2]
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            x, pad_hw = window_partition(x, self.window_size)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        x = self.attn(x)
and was used here:
  File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 177
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

HIMANSHUSINGHYANIA • Apr 18 '23 11:04

Huh?

LownyCGI • Apr 18 '23 18:04

@HIMANSHUSINGHYANIA Yep, that's what this PR is fixing. 👍

fkodom • Apr 18 '23 18:04

I don't know if this helps, but I've implemented a working TorchScript prototype. You can check it out here: https://github.com/csia-pme/djl-image-sam

leonardcser • Apr 21 '23 12:04

Huh?

HIMANSHUSINGHYANIA • Apr 27 '23 09:04