TensorRT
❓ [Question] Is SAM2 supported when compiling with the Dynamo backend on JetPack 6.1 or 6.2?
❓ Question
Will SAM2 be compatible with the Dynamo backend on JetPack 6.1/6.2?
Are there any workarounds for the TensorRT version mismatch?
What you have already tried
Here are my attempts and the issues encountered. My device is a Jetson AGX Orin, and I only compile the ImageEncoder of SAM2 (Hiera & FPN, with position_encoding removed). The SAM2 code is from https://github.com/chohk88/sam2/tree/torch-trt:
JetPack 6.1 + PyTorch 2.5 (from https://developer.download.nvidia.cn) + Torch-TensorRT 2.5
Tried compiling SAM2 but encountered errors.
Observed that the PyTorch 2.5 documentation does not mention SAM2 support, likely indicating SAM2 is not yet adapted for this version.
JetPack 6.1 + PyTorch 2.6 (from https://pypi.jetson-ai-lab.dev/jp6/cu126) + Torch-TensorRT 2.6
Installed PyTorch 2.6 from jp6/cu126 and Torch-TensorRT 2.6.
Importing torch_tensorrt failed with ModuleNotFoundError: No module named 'tensorrt.plugin'.
Root cause: Torch-TensorRT 2.6 requires TensorRT 10.7, but JetPack 6.1 provides only TensorRT 10.3.
Found no straightforward way to upgrade TensorRT within JetPack 6.1 due to dependency conflicts.
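A quick way to see the mismatch up front (a minimal sketch of my own, not from any guide):
import tensorrt
# JetPack 6.1's wheels report 10.3.x here, while Torch-TensorRT 2.6 expects 10.7
print(tensorrt.__version__)
# With mismatched versions, this import is what fails:
import torch_tensorrt  # ModuleNotFoundError: No module named 'tensorrt.plugin'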
Cross-Platform Attempt: Compile on x86 + Run on JetPack 6.1
Compiled SAM2 on x86 with Torch-TensorRT 2.6 and exported the model.
Tried running it on JetPack 6.1 with Torch-TensorRT 2.5.
Failed unsurprisingly due to serialization version incompatibility between 2.6 and 2.5.
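For clarity, the round trip I attempted looks roughly like this (a sketch; model and example_inputs are placeholders, and the file name is illustrative):
import torch_tensorrt
# On x86 with Torch-TensorRT 2.6:
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=example_inputs)
torch_tensorrt.save(trt_mod, "image_encoder.ep", inputs=example_inputs)
# On Jetson with Torch-TensorRT 2.5 the deserialization then fails,
# since the exported program format is tied to the version that wrote it:
reloaded = torch_tensorrt.load("image_encoder.ep").module()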
cc @peri044 @chohk88
@AyanamiReiFan I don't know of any workarounds for upgrading TRT 10.3 on JetPack. That being said, you could give the 25.03-py3-igpu container a try. This container has TRT 10.9 and the corresponding Torch-TRT version. This might work, although I haven't tested it yet. In the future, JetPack 7 will have TRT 10.6+, which could also fix this issue.
The iGPU container should also have a much more recent version of Torch-TRT.
Thanks very much! I will try it later.
@AyanamiReiFan
I have fixed the No module named 'tensorrt.plugin' error.
Here is the PR merged to main: https://github.com/pytorch/TensorRT/pull/3518
You can follow the guide on this branch: https://github.com/pytorch/TensorRT/blob/b0baba1d9687ad8a8f1db577abd029e38e3555af/docsrc/getting_started/jetpack.rst
Meanwhile, I will also give SAM2 a try on Jetson.
@AyanamiReiFan I have successfully built and installed SAM2 on Jetson Orin.
Since https://github.com/pytorch/TensorRT/pull/3524/files has not been merged to main yet, please follow the JetPack guide here for building from latest: https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/docsrc/getting_started/jetpack.rst
Thank you very much. I've read this JetPack guide. Can I assume that this guide essentially installs PyTorch 2.7 and the not-yet-officially-released Torch-TensorRT 2.8 on JetPack 6.2-based Jetson devices? If so, once Torch-TensorRT 2.8 is officially released, will it be properly compatible with JetPack 6.2 Jetson devices?
Additionally, I’m currently trying to follow these installation steps and will provide feedback as soon as possible.
I built torch-tensorrt from the main branch.
When executing the command:
python setup.py bdist_wheel --jetpack
I got this:
Loading: 0 packages loaded
and the wheel build failed.
Detailed log:
xxxxxxx@ubuntu:~/Develops/env_prepare/TensorRT$ python setup.py bdist_wheel --jetpack
2025/06/02 21:55:43 Downloading https://releases.bazel.build/8.1.1/release/bazel-8.1.1-linux-arm64...
Extracting Bazel installation...
Starting local Bazel server (8.1.1) and connecting to it...
no actions running
no actions running
no actions running
no actions running
no actions running
DEBUG: Rule 'rules_pkg+' indicated that a canonical reproducible form can be obtained by modifying arguments commit = "17c57f46e5c7cd58f893d7960b4fe6fe59bb77b1"
DEBUG: Repository rules_pkg+ instantiated at:
I have identified the reason for the "Loading: 0 packages loaded" issue mentioned above. When configuring the system proxy on Ubuntu, I forgot to set no_proxy, causing traffic that should have gone to the Bazel server to be incorrectly routed to the proxy server.
I got this error when installing the wheel. Is torch 2.7 not supported by the main branch?
xxxxxxx@ubuntu:~/Develops/env_prepare/TensorRT/dist$ python -m pip install torch_tensorrt-2.8.0.dev0+727cbd2e9-cp310-cp310-linux_aarch64.whl
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
Processing ./torch_tensorrt-2.8.0.dev0+727cbd2e9-cp310-cp310-linux_aarch64.whl
INFO: pip is looking at multiple versions of torch-tensorrt to determine which version is compatible with other requirements. This could take a while.
ERROR: Could not find a version that satisfies the requirement torch<2.9.0,>=2.8.0.dev (from torch-tensorrt) (from versions: 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0)
ERROR: No matching distribution found for torch<2.9.0,>=2.8.0.dev
@AyanamiReiFan Did you use the latest code from this branch: https://github.com/pytorch/TensorRT/tree/lluo/jetson_build
In the pyproject, if your environment is tegra, it should pull "torch>=2.7.0,<2.8.0; 'tegra' in platform_release":
https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/pyproject.toml#L13
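For reference, that environment marker keys off the kernel release string (PEP 508 maps platform_release to platform.release()); a quick check on a Jetson shows whether the tegra spec applies:
import platform
# On JetPack 6.x this prints something like "5.15.148-tegra" (exact value varies),
# so the marker 'tegra' in platform_release selects torch>=2.7.0,<2.8.0.
print(platform.release())
print("tegra" in platform.release())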
I used the main branch, not this branch.
I added --no-deps and it seems to work.
The script examples/dynamo/torch_compile_resnet_example.py seems to run correctly.
I will try running SAM2 with both the main branch and your branch later.
When using your branch, I could compile and run SAM2-Large, but the image quality is significantly worse than in the examples. Additionally, when attempting to use the base+ or tiny models, compilation fails with the following errors:
utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed.
ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists
utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed.
ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists
utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed.
ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[[SLICE]-[aten_ops.slice.Tensor]-[/slice_18]...[SHUFFLE]-[aten_ops.squeeze.dim]-[/squeeze_63] + [SHUFFLE]-[aten_ops._reshape_copy.default]-[/_reshape_copy_424] + [SHUFFLE]-[aten_ops.permute.default]-[/permute_246]]}.)
Traceback (most recent call last):
File "/home/xxxxxxx/Develops/sam2-torch-trt/export_sam2.py", line 257, in <module>
trt_model = torch_tensorrt.dynamo.compile(
File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 712, in compile
trt_gm = compile_module(
File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 918, in compile_module
trt_module = convert_module(
File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 90, in convert_module
interpreter_result = interpret_module_to_result(
File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 69, in interpret_module_to_result
interpreter_result = interpreter.run()
File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 739, in run
assert serialized_engine
AssertionError
The experimental code I used is from there, and I removed line 38:
from sam_components import SAM2FullModel
Resulting images:
After making some modifications, I was able to successfully compile the Hiera-Tiny model (since I only need the Hiera part of SAM2). However, I'm not entirely sure which specific change(s) actually resolved the issue. Below are the modification details; I hope they are helpful for your development work.
- I modified hieradet.py to cache pos_embed during model initialization (a small usage sketch follows the file):
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from functools import partial
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr
from sam2.modeling.backbones.utils import (
PatchEmbed,
window_partition,
window_unpartition,
)
from sam2.modeling.sam2_utils import DropPath, MLP
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
if pool is None:
return x
# (B, H, W, C) -> (B, C, H, W)
x = x.permute(0, 3, 1, 2)
x = pool(x)
# (B, C, H', W') -> (B, H', W', C)
x = x.permute(0, 2, 3, 1)
if norm:
x = norm(x)
return x
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
q_pool: nn.Module = None,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (B, H * W, 3, nHead, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
# q, k, v with shape (B, H * W, nheads, C)
q, k, v = torch.unbind(qkv, 2)
# Q pooling (for downsample at stage changes)
if self.q_pool:
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
H, W = q.shape[1:3] # downsampled shape
q = q.reshape(B, H * W, self.num_heads, -1)
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
x = F.scaled_dot_product_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
)
# Transpose back
x = x.transpose(1, 2)
x = x.reshape(B, H, W, -1)
x = self.proj(x)
return x
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
norm_layer: Union[nn.Module, str] = "LayerNorm",
q_stride: Tuple[int, int] = None,
act_layer: nn.Module = nn.GELU,
window_size: int = 0,
):
super().__init__()
if isinstance(norm_layer, str):
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
self.window_size = window_size
self.pool, self.q_stride = None, q_stride
if self.q_stride:
self.pool = nn.MaxPool2d(
kernel_size=q_stride, stride=q_stride, ceil_mode=False
)
self.attn = MultiScaleAttention(
dim,
dim_out,
num_heads=num_heads,
q_pool=self.pool,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim_out)
self.mlp = MLP(
dim_out,
int(dim_out * mlp_ratio),
dim_out,
num_layers=2,
activation=act_layer,
)
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x # B, H, W, C
x = self.norm1(x)
# Skip connection
if self.dim != self.dim_out:
shortcut = do_pool(self.proj(x), self.pool)
# Window partition
window_size = self.window_size
if window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, window_size)
# Window Attention + Q Pooling (if stage change)
x = self.attn(x)
if self.q_stride:
# Shapes have changed due to Q pooling
window_size = self.window_size // self.q_stride[0]
H, W = shortcut.shape[1:3]
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
pad_hw = (H + pad_h, W + pad_w)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
# MLP
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Hiera(nn.Module):
"""
Reference: https://arxiv.org/abs/2306.00989
"""
def __init__(
self,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
drop_path_rate: float = 0.0, # stochastic depth
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
dim_mul: float = 2.0, # dim_mul factor at stage shift
head_mul: float = 2.0, # head_mul factor at stage shift
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
# window size per stage, when not using global att.
window_spec: Tuple[int, ...] = (
8,
4,
14,
7,
),
# global attn in these blocks
global_att_blocks: Tuple[int, ...] = (
12,
16,
20,
),
weights_path=None,
return_interm_layers=True, # return feats from every stage
input_size=None,
):
super().__init__()
assert len(stages) == len(window_spec)
self.window_spec = window_spec
depth = sum(stages)
self.q_stride = q_stride
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
assert 0 <= q_pool <= len(self.stage_ends[:-1])
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
self.return_interm_layers = return_interm_layers
self.patch_embed = PatchEmbed(
embed_dim=embed_dim,
)
# Which blocks have global att?
self.global_att_blocks = global_att_blocks
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
self.pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
)
self.pos_embed_window = nn.Parameter(
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
cur_stage = 1
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
# lags by a block, so first block of
# next stage uses an initial window size
# of previous stage and final window size of current stage
window_size = self.window_spec[cur_stage - 1]
if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
drop_path=dpr[i],
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
)
embed_dim = dim_out
self.blocks.append(block)
self.channel_list = (
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
if return_interm_layers
else [self.blocks[-1].dim_out]
)
if weights_path is not None:
with g_pathmgr.open(weights_path, "rb") as f:
chkpt = torch.load(f, map_location="cpu")
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
self.input_size = input_size
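# Modification: when input_size is known at build time, precompute the resized
# positional embedding once here, so the bicubic interpolation in _get_pos_embed
# does not have to be traced into the exported graph.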
if self.input_size is not None:
self.register_buffer(
"resized_pos_embed",
self._get_pos_embed((self.input_size[0] // 4, self.input_size[1] // 4)).clone().detach(),
persistent=False
)
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
pos_embed = pos_embed + window_embed.tile(
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
)
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.patch_embed(x)
# x: (B, H, W, C)
# Add pos embed
if self.input_size is not None:
x = x + self.resized_pos_embed
else:
x = x + self._get_pos_embed(x.shape[1:3])
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if (i == self.stage_ends[-1]) or (
i in self.stage_ends and self.return_interm_layers
):
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs
def get_layer_id(self, layer_name):
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
num_layers = self.get_num_layers()
if layer_name.find("rel_pos") != -1:
return num_layers + 1
elif layer_name.find("pos_embed") != -1:
return 0
elif layer_name.find("patch_embed") != -1:
return 0
elif layer_name.find("blocks") != -1:
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
else:
return num_layers + 1
def get_num_layers(self) -> int:
return len(self.blocks)
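To show how the cached path is exercised, here is a minimal sketch of my own (constructor arguments follow the trimmed sam2.1_hiera_t.yaml shown further down). One caveat worth double-checking: the buffer is filled in __init__, so a checkpoint loaded afterwards (e.g. via build_sam2's ckpt_path) does not refresh it.
import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera(
    embed_dim=96,
    num_heads=1,
    stages=(1, 2, 7, 2),
    global_att_blocks=(5, 7, 9),
    window_pos_embed_bkg_spatial_size=(7, 7),
    input_size=(672, 896),  # enables the cached resized_pos_embed buffer
).eval()

x = torch.randn(1, 3, 672, 896)
with torch.no_grad():
    feats = trunk(x)  # uses the precomputed buffer; no interpolate in forward()
for f in feats:
    print(f.shape)  # one (B, C, H, W) feature map per stage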
- I changed the compile code; this is my compile script:
import os.path
import time
import random
import torch
import torch_tensorrt
import tqdm
from torchvision import transforms
from PIL import Image
from sam2.build_sam import build_sam2
def loop(model, input_shape=(1024, 1024), precision=torch.float32):
with torch.no_grad():
for _ in range(25):
src = torch.randn((1, 3, *input_shape), device='cuda', dtype=precision)
model(src)
torch.cuda.synchronize()
start_time = time.time()
for _ in tqdm.tqdm(range(5000)):
src = torch.randn((1, 3, *input_shape), device='cuda', dtype=precision)
model(src)
torch.cuda.synchronize()
print(5000 / (time.time() - start_time))
def pil_to_tensor(image, target_size=(224, 224)):
# Define the preprocessing pipeline
preprocess = transforms.Compose([
transforms.Resize(target_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Apply the preprocessing and add a batch dimension
input_tensor = preprocess(image).unsqueeze(0)
return input_tensor
def check_precision(new_model, origin_model, input_shape=(1024, 1024), precision=torch.float32,
test_image_root=None):
all_image_list = []
if test_image_root is not None:
all_image_list = [os.path.join(test_image_root, each) for each in os.listdir(test_image_root)]
for i in range(10):
if not all_image_list:  # note: an `is None` check never triggers, since the list defaults to []
dummy_input = torch.randn((1, 3, *input_shape), device='cuda', dtype=precision)
else:
image_name = random.choice(all_image_list)
real_image = Image.open(image_name)
print(os.path.basename(image_name))
dummy_input = pil_to_tensor(real_image, input_shape).to(device='cuda', dtype=precision)
print(dummy_input.dtype)
new_model_result = new_model(dummy_input)
origin_model_result = origin_model(dummy_input.float())
for new_feature_map, old_feature_map in zip(new_model_result, origin_model_result):
# old_feature_map = old_feature_map.to(precision)
print(old_feature_map.shape)
abs_diff = (new_feature_map - old_feature_map).abs()
mean_diff = abs_diff.mean().item()
max_diff = abs_diff.max().item()
# rel_diff = abs_diff / (old_feature_map.abs() + 1e-7)  # avoid division by zero
# mean_rel_diff = rel_diff.mean().item()
# max_rel_diff = rel_diff.max().item()
cos_sim = torch.cosine_similarity(new_feature_map, old_feature_map, dim=1)
mean_cos = cos_sim.mean().item()
min_cos = cos_sim.min().item()
print(mean_diff, max_diff, mean_cos, min_cos)
@torch.no_grad()
def main():
sam2_config = 'configs/sam2.1_only_encoder/sam2.1_hiera_t.yaml'
sam2_ckpt = './pretrained_checkpoint/sam2.1_hiera_tiny.pt'
raw_sam2_hiera = build_sam2(sam2_config, ckpt_path=sam2_ckpt, apply_postprocessing=False, removed_prefix='image_encoder.trunk.')
raw_sam2_hiera.eval()
sam2_hiera = build_sam2(sam2_config, ckpt_path=sam2_ckpt, apply_postprocessing=False, removed_prefix='image_encoder.trunk.')
sam2_hiera.eval()
inference_shape = (672, 896)
real_check_data_dir = './test_image'
# compile_type = None
compile_type = 'dynamo'
history_compiled_model = "tiny.ep"
# compile_type = 'tensorrt'
# history_compiled_model = "tiny.ts"
if compile_type == 'dynamo':
if os.path.exists(history_compiled_model):
sam2_hiera = torch_tensorrt.load(history_compiled_model).module()
print('load history compiled ep model')
else:
if real_check_data_dir is not None:
all_image_list = [os.path.join(real_check_data_dir, each) for each in os.listdir(real_check_data_dir)]
dummy_input = pil_to_tensor(
Image.open(random.choice(all_image_list)), inference_shape
).to(device='cuda', dtype=torch.float16)
else:
dummy_input = torch.randn((1, 3, *inference_shape), device='cuda', dtype=torch.float16)
# sam2_hiera.half()
print('start compile')
exp_program = torch.export.export(
sam2_hiera.half(),
(dummy_input,),
strict=True
)
print('finish compile stage 1')
dummy_input = torch_tensorrt.Input(
shape_mode=0,
shape=(1, 3, *inference_shape),
dtype=torch_tensorrt.dtype.float16,
)
sam2_hiera = torch_tensorrt.dynamo.compile(
exp_program,
inputs=(dummy_input,),
min_block_size=1,
enabled_precisions={torch.float16},
use_fp32_acc=True,
optimization_level=5,
# device=torch_tensorrt.Device("dla:0", allow_gpu_fallback=True),
)
print('finish compile stage 2')
torch_tensorrt.save(sam2_hiera, history_compiled_model, inputs=(dummy_input,))
# sam2_hiera = torch.export.load(history_compiled_model).module()
sam2_hiera = torch_tensorrt.load(history_compiled_model).module()
print('finish compile')
# trt_out = sam2_hiera(dummy_input)
elif compile_type == 'tensorrt':
if os.path.exists(history_compiled_model):
sam2_hiera = torch_tensorrt.load(history_compiled_model)
print('load history compiled ts model')
else:
print('start compile')
dummy_input = torch.randn((1, 3, *inference_shape), device='cuda', dtype=torch.float16)
scripted = torch.jit.trace(
sam2_hiera.half(),
dummy_input
)
scripted = torch.jit.freeze(scripted)
print('finish compile stage 1')
# Static compilation
sam2_hiera = torch_tensorrt.compile(
scripted,
# sam2_hiera.half(),
ir='torchscript',
inputs=[torch_tensorrt.Input(shape=dummy_input.shape, dtype=torch.float16)],
enabled_precisions={torch.float16},
workspace_size=1 << 30,
truncate_long_and_double=True
)
print('finish compile stage 2')
torch_tensorrt.save(sam2_hiera, history_compiled_model, output_format="torchscript", inputs=(dummy_input,))
sam2_hiera = torch_tensorrt.load(history_compiled_model)
print('finish compile')
else:
assert compile_type is None
sam2_hiera = sam2_hiera.half()
# print(sam2_hiera)
check_precision(sam2_hiera, raw_sam2_hiera, input_shape=inference_shape,
precision=torch.float16, test_image_root=real_check_data_dir)
loop(sam2_hiera, input_shape=inference_shape, precision=torch.float16)
if __name__ == '__main__':
main()
- (Not important) I modified build_sam.py to only load weights for Hiera:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
import sam2
# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
# If the user has "sam2/sam2" in their path, they are likely importing the repo itself
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
# This typically happens because the user is running Python from the parent directory
# that contains the sam2 repo they cloned.
raise RuntimeError(
"You're likely running Python from the parent directory of the sam2 repository "
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
"This is not supported since the `sam2` Python package could be shadowed by the "
"repository name (the repository is also named `sam2` and contains the Python package "
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
"rather than its parent dir, or from your home directory) after installing SAM 2."
)
HF_MODEL_ID_TO_FILENAMES = {
"facebook/sam2-hiera-tiny": (
"configs/sam2/sam2_hiera_t.yaml",
"sam2_hiera_tiny.pt",
),
"facebook/sam2-hiera-small": (
"configs/sam2/sam2_hiera_s.yaml",
"sam2_hiera_small.pt",
),
"facebook/sam2-hiera-base-plus": (
"configs/sam2/sam2_hiera_b+.yaml",
"sam2_hiera_base_plus.pt",
),
"facebook/sam2-hiera-large": (
"configs/sam2/sam2_hiera_l.yaml",
"sam2_hiera_large.pt",
),
"facebook/sam2.1-hiera-tiny": (
"configs/sam2.1/sam2.1_hiera_t.yaml",
"sam2.1_hiera_tiny.pt",
),
"facebook/sam2.1-hiera-small": (
"configs/sam2.1/sam2.1_hiera_s.yaml",
"sam2.1_hiera_small.pt",
),
"facebook/sam2.1-hiera-base-plus": (
"configs/sam2.1/sam2.1_hiera_b+.yaml",
"sam2.1_hiera_base_plus.pt",
),
"facebook/sam2.1-hiera-large": (
"configs/sam2.1/sam2.1_hiera_l.yaml",
"sam2.1_hiera_large.pt",
),
}
def build_sam2(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
removed_prefix=None,
**kwargs,
):
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path, removed_prefix)
model = model.to(device)
if mode == "eval":
model.eval()
return model
def build_sam2_video_predictor(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
"++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)
if mode == "eval":
model.eval()
return model
def _hf_download(model_id):
from huggingface_hub import hf_hub_download
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
return config_name, ckpt_path
def build_sam2_hf(model_id, **kwargs):
config_name, ckpt_path = _hf_download(model_id)
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
def build_sam2_video_predictor_hf(model_id, **kwargs):
config_name, ckpt_path = _hf_download(model_id)
return build_sam2_video_predictor(
config_file=config_name, ckpt_path=ckpt_path, **kwargs
)
def _load_checkpoint(model, ckpt_path, removed_prefix=None):
if ckpt_path is not None:
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
if removed_prefix is not None:
pos = len(removed_prefix)
sd = {
k[pos:]: v for k, v in sd.items() if k.startswith(removed_prefix)
}
missing_keys, unexpected_keys = model.load_state_dict(sd)
if missing_keys:
logging.error(missing_keys)
raise RuntimeError()
if unexpected_keys:
logging.error(unexpected_keys)
raise RuntimeError()
logging.info("Loaded checkpoint sucessfully")
- (Not important) I modified sam2.1_hiera_t.yaml to only build Hiera (a usage sketch follows the config):
# @package _global_
# Model
model:
_target_: sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 7, 2]
global_att_blocks: [5, 7, 9]
window_pos_embed_bkg_spatial_size: [7, 7]
input_size: [672, 896]
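For completeness, loading through this trimmed config (paths and removed_prefix as in the compile script above) yields just the trunk, which you can sanity-check via its per-stage channels:
from sam2.build_sam import build_sam2

hiera = build_sam2(
    'configs/sam2.1_only_encoder/sam2.1_hiera_t.yaml',
    ckpt_path='./pretrained_checkpoint/sam2.1_hiera_tiny.pt',
    apply_postprocessing=False,
    removed_prefix='image_encoder.trunk.',  # keep only the Hiera weights
)
print(hiera.channel_list)  # [768, 384, 192, 96] for the tiny config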
@AyanamiReiFan could you please give some insight into how you solved the "Loading: 0 packages loaded" error? I'm stuck at this part even after trying unset. Also, did you try installing on JetPack 6.2?
I installed it on Jetpack 6.2 on a Jetson Orin Developer Kit (64GB). I previously encountered the "Loading: 0 packages loaded" error because I had set a system proxy on Ubuntu via the command line but forgot to exclude local addresses from the proxy. As a result, requests to addresses like 127.0.0.1 were also being routed through my proxy.
Based on the logs, I deduced that Bazel likely starts a local server to run, and because the client requests were being misrouted, the server wasn't receiving them. Therefore, I modified my proxy configuration by adding the no_proxy setting. After this change, I no longer encountered the "Loading: 0 packages loaded" issue.
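For anyone else hitting this, a quick diagnostic (my own sketch) is to print the proxy-related environment variables in the environment you build from; loopback addresses should be excluded before running setup.py:
import os

# Bazel's client talks to a local Bazel server; if http(s)_proxy is set and
# no_proxy does not exclude loopback, those requests get routed to the proxy.
for var in ("http_proxy", "https_proxy", "no_proxy", "NO_PROXY"):
    print(f"{var}={os.environ.get(var, '')}")
# Expected: no_proxy includes localhost and 127.0.0.1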