segment-anything
Make `ImageEncoderViT` compilable with `torch.jit.script`
It would be very convenient if `ImageEncoderViT` were compilable with `torch.jit.script`.
Issue
Load the `Sam` model and attempt to compile its image encoder with `torch.jit.script`:
```python
import torch
from segment_anything import sam_model_registry
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint="...")
scripted = torch.jit.script(sam.image_encoder)
```
which raises the following error:
```
RuntimeError:
pad_hw is not defined in the false branch:
File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 170
x = self.norm1(x)
# Window partition
if self.window_size > 0:
~~~~~~~~~~~~~~~~~~~~~~~~
H, W = x.shape[1], x.shape[2]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
x, pad_hw = window_partition(x, self.window_size)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x = self.attn(x)
and was used here:
File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 177
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
~~~~~~ <--- HERE
x = shortcut + x
```
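To see the rule in isolation, here is a minimal sketch that mirrors the structure of the transformer block's `forward` in `image_encoder.py`: a variable bound under one `if` and consumed under a second, separate `if` on the same condition. The function `two_checks` is a hypothetical stand-in, not part of `segment_anything`. Eager Python runs it fine, but scripting it should fail with the same kind of "not defined in the false branch" error.

```python
import torch


def two_checks(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # Same shape as the original forward: `scale` is only bound when the
    # first check is true, then used under a second, separate check.
    if window_size > 0:
        scale = float(window_size)
    x = x + 1.0
    if window_size > 0:
        x = x * scale
    return x


try:
    torch.jit.script(two_checks)
    scripted_ok = True
except RuntimeError as err:
    # TorchScript rejects the function at compile time.
    scripted_ok = False
    print(err)  # prints the TorchScript compile error
```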
Solution
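A simple fix is to avoid writing `if self.window_size > 0:` in two separate places. TorchScript requires any variable used after an `if` to be defined on both of its branches, and it does not infer that two identical checks always take the same branch, so `pad_hw` (and `H`, `W`) must be produced and consumed under a single check.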
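A minimal sketch of that restructuring, under stated assumptions: `ToyBlock`, `window_partition`, and `window_unpartition` below are stand-ins written for this example. The partition helpers follow the shape logic of the ones in `segment_anything/modeling/image_encoder.py`, but the attention module is replaced by a hypothetical `nn.Linear`. Partition, attention, and unpartition all sit under one `if`, with a plain attention call in the `else`, so `pad_hw`, `H`, and `W` are only used where they are defined.

```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Pad (B, H, W, C) up to multiples of window_size, then split into windows.
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    # Reassemble windows into (B, Hp, Wp, C), then crop the padding away.
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


class ToyBlock(nn.Module):
    def __init__(self, dim: int, window_size: int) -> None:
        super().__init__()
        self.window_size = window_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)  # hypothetical stand-in for the real Attention module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            # Partition, attend, and unpartition under ONE check, so pad_hw,
            # H and W are defined everywhere they are used.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
            x = self.attn(x)
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        else:
            x = self.attn(x)
        return shortcut + x


block = ToyBlock(dim=8, window_size=4).eval()
scripted = torch.jit.script(block)  # compiles, unlike the two-check version
```

This mirrors the fix described above rather than reproducing the PR's exact diff; the real block also applies an MLP after the residual, which is omitted here.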
```python
import numpy as np
import torch
import torchvision
from segment_anything import sam_model_registry

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

model_checkpoint = '/path/to/sam_vit_b_01ec64.pth'
sam = sam_model_registry["vit_b"](checkpoint=model_checkpoint).to(device='cuda')

images = np.zeros((10, 256, 256, 3), dtype='uint8')

batched_input = []
for i in range(images.shape[0]):
    batched_input.append({
        'image': torch.as_tensor(images[i], device=sam.device).permute(2, 0, 1).contiguous(),
        'original_size': images[i].shape[:2],
    })
```
```
pad_hw is not defined in the false branch:
File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 170
x = self.norm1(x)
# Window partition
if self.window_size > 0:
~~~~~~~~~~~~~~~~~~~~~~~~
H, W = x.shape[1], x.shape[2]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
x, pad_hw = window_partition(x, self.window_size)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x = self.attn(x)
and was used here:
File "/home/fodom_plainsight_ai/.pyenv/versions/smartpoly-sam/lib/python3.10/site-packages/segment_anything/modeling/image_encoder.py", line 177
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
```
Huh?
@HIMANSHUSINGHYANIA Yep, that's what this PR is fixing. 👍
I don't know if this helps, but I've implemented a working prototype of a TorchScript export. You can check it out here: https://github.com/csia-pme/djl-image-sam