accelerated_features
Add support for exporting to ONNX.
- Support for exporting the xfeat and xfeat+matching models.
- Support for exporting with dynamic shapes.
- Support for exporting with a specified opset version (see the export sketch after the table below).
- Add an onnxruntime inference demo.
Model | Inputs | Outputs | Note |
---|---|---|---|
xfeat.onnx | images | feats, keypoints, heatmaps | Extract image keypoints and features. |
xfeat_dualscale.onnx | images | mkpts, feats, sc | Extract dualscale image keypoints and features. |
matching.onnx | mkpts0, feats0, sc0, mkpts1, feats1 | matches, batch_indexes | Match dualscale keypoints and features. |
xfeat_matching.onnx | images0, images1 | matches, batch_indexes | End-to-end feature extraction and matching for two sets of images. |
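For reference, here is a minimal export sketch, not the PR's actual export script: the placeholder module, output file name, and dynamic axes are illustrative assumptions, only showing how torch.onnx.export can be asked for dynamic shapes and a specific opset version.

import torch
import torch.nn as nn

# Placeholder module standing in for the real XFeat network (illustration only).
class TinyNet(nn.Module):
    def forward(self, images):
        feats = images.mean(dim=1, keepdim=True)      # stand-in "feats"
        keypoints = images.amax(dim=1, keepdim=True)  # stand-in "keypoints"
        heatmaps = images[:, :1]                      # stand-in "heatmaps"
        return feats, keypoints, heatmaps

model = TinyNet().eval()
dummy = torch.randn(1, 3, 640, 640)  # N, C, H, W placeholder input
torch.onnx.export(
    model, dummy, "xfeat_export_sketch.onnx",
    input_names=["images"],
    output_names=["feats", "keypoints", "heatmaps"],
    dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"}},  # mark batch/spatial dims as dynamic
    opset_version=17,  # the desired opset version is selected here
)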
examples
- FPS is limited by this line, especially when a large number of keypoints is matched (one possible mitigation is sketched below): https://github.com/verlab/accelerated_features/blob/b2f6b99b85db0672fc5e49df016850370ce8ad57/realtime_demo.py#L239
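One possible mitigation, a sketch rather than part of this PR: cap how many matches are handed to the drawing call each frame with a small helper like the hypothetical one below.

import numpy as np

def subsample_matches(matches, max_draw=200, seed=0):
    # Keep at most `max_draw` matches for visualization; drawing thousands of
    # cv2.DMatch objects per frame is what limits FPS at the referenced line.
    if len(matches) <= max_draw:
        return matches
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(matches), size=max_draw, replace=False)
    return [matches[i] for i in idx]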
xfeat_onnxruntime.py
import numpy as np
import onnxruntime as ort


def create_ort_session(model_path, trt_engine_cache_path='trt_engine_cache', trt_engine_cache_prefix='model'):
    # Open a temporary CPU session just to print the input/output names and shapes.
    tmp_ort_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    for i in range(len(tmp_ort_session.get_inputs())):
        print(f"Input name: {tmp_ort_session.get_inputs()[i].name}, shape: {tmp_ort_session.get_inputs()[i].shape}")
    for i in range(len(tmp_ort_session.get_outputs())):
        print(f"Output name: {tmp_ort_session.get_outputs()[i].name}, shape: {tmp_ort_session.get_outputs()[i].shape}")

    providers = [
        # The TensorrtExecutionProvider is the fastest.
        ('TensorrtExecutionProvider', {
            'device_id': 0,
            'trt_max_workspace_size': 4 * 1024 * 1024 * 1024,
            'trt_fp16_enable': True,
            'trt_engine_cache_enable': True,
            'trt_engine_cache_path': trt_engine_cache_path,
            'trt_engine_cache_prefix': trt_engine_cache_prefix,
            'trt_dump_subgraphs': False,
            'trt_timing_cache_enable': True,
            'trt_timing_cache_path': trt_engine_cache_path,
            # 'trt_builder_optimization_level': 3,
        }),
        # The CUDAExecutionProvider is slower than PyTorch, possibly due to performance issues
        # with the large matrix multiplication "cossim = torch.bmm(feats1, feats2.permute(0,2,1))".
        # Reducing the top_k value when exporting to ONNX can decrease the matrix size.
        ('CUDAExecutionProvider', {
            'device_id': 0,
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,
        }),
        ('CPUExecutionProvider', {}),
    ]
    ort_session = ort.InferenceSession(model_path, providers=providers)
    return ort_session


class XFeat:
    def __init__(self, xfeat_model_path='./xfeat_dualscale.onnx', matcher_model_path='./matching.onnx'):
        self.xfeat_ort_session = create_ort_session(xfeat_model_path, trt_engine_cache_prefix='xfeat_dualscale')
        self.matcher_ort_session = create_ort_session(matcher_model_path, trt_engine_cache_prefix='matching')
        # Warm up both sessions so the first real frame does not pay the initialization cost.
        for i in range(5):
            image = np.zeros((640, 640, 3), dtype=np.float32)
            self.detectAndCompute(image)
            mkpts0 = np.zeros((1, 4800, 2), dtype=np.float32)
            feats0 = np.zeros((1, 4800, 64), dtype=np.float32)
            sc0 = np.zeros((1, 4800), dtype=np.float32)
            mkpts1 = np.zeros((1, 4800, 2), dtype=np.float32)
            feats1 = np.zeros((1, 4800, 64), dtype=np.float32)
            self.match(mkpts0, feats0, sc0, mkpts1, feats1)

    def detectAndCompute(self, image_data, mask=None):
        # HWC float image -> NCHW float32 batch of one.
        input_array = np.expand_dims(image_data.transpose((2, 0, 1)), axis=0).astype(np.float32)
        inputs = {
            self.xfeat_ort_session.get_inputs()[0].name: input_array,
        }
        mkpts0, feats0, sc = self.xfeat_ort_session.run(None, inputs)
        return {
            "keypoints": mkpts0,
            "descriptors": feats0,
            "sc": sc,
        }

    def match(self, mkpts0, feats0, sc0, mkpts1, feats1):
        inputs = {
            self.matcher_ort_session.get_inputs()[0].name: mkpts0,
            self.matcher_ort_session.get_inputs()[1].name: feats0,
            self.matcher_ort_session.get_inputs()[2].name: sc0,
            self.matcher_ort_session.get_inputs()[3].name: mkpts1,
            self.matcher_ort_session.get_inputs()[4].name: feats1,
        }
        matches, batch_indexes = self.matcher_ort_session.run(None, inputs)
        return matches, batch_indexes
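A minimal usage sketch for the wrapper above; the file name, image paths, and the 640x640 resize are assumptions, and the images must match a resolution the exported model accepts.

import cv2
import numpy as np
from xfeat_onnxruntime import XFeat  # assuming the code above is saved as xfeat_onnxruntime.py

xfeat = XFeat('./xfeat_dualscale.onnx', './matching.onnx')

img0 = cv2.resize(cv2.imread('./ref.png', cv2.IMREAD_COLOR), (640, 640)).astype(np.float32)
img1 = cv2.resize(cv2.imread('./tgt.png', cv2.IMREAD_COLOR), (640, 640)).astype(np.float32)

out0 = xfeat.detectAndCompute(img0)
out1 = xfeat.detectAndCompute(img1)
matches, batch_indexes = xfeat.match(out0['keypoints'], out0['descriptors'], out0['sc'],
                                     out1['keypoints'], out1['descriptors'])
# Each row of `matches` holds (x0, y0, x1, y1); batch_indexes selects a pair within the batch.
points0, points1 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]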
realtime_demo.py patches
diff --git a/realtime_demo.py b/realtime_demo.py
index 6c867fd..8e21ae2 100644
--- a/realtime_demo.py
+++ b/realtime_demo.py
@@ -7,20 +7,18 @@
 import cv2
 import numpy as np
-import torch
 
 from time import time, sleep
 import argparse, sys, tqdm
 import threading
 
-from modules.xfeat import XFeat
 
 def argparser():
     parser = argparse.ArgumentParser(description="Configurations for the real-time matching demo.")
     parser.add_argument('--width', type=int, default=640, help='Width of the video capture stream.')
     parser.add_argument('--height', type=int, default=480, help='Height of the video capture stream.')
     parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.')
-    parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.')
+    parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat', 'XFeat_Ort'], default='XFeat', help='Local feature detection method to use.')
     parser.add_argument('--cam', type=int, default=0, help='Webcam device number.')
     return parser.parse_args()
 
@@ -52,6 +50,7 @@ class CVWrapper():
     def __init__(self, mtd):
         self.mtd = mtd
     def detectAndCompute(self, x, mask=None):
+        import torch
         return self.mtd.detectAndCompute(torch.tensor(x).permute(2,0,1).float()[None])[0]
 
 class Method:
@@ -65,7 +64,12 @@ def init_method(method, max_kpts):
     elif method == "SIFT":
         return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True))
     elif method == "XFeat":
+        from modules.xfeat import XFeat
         return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat())
+    elif method == "XFeat_Ort":
+        from xfeat_onnxruntime import XFeat
+        xfeat = XFeat()
+        return Method(descriptor=xfeat, matcher=xfeat)
     else:
         raise RuntimeError("Invalid Method.")
 
@@ -200,6 +204,12 @@ class MatchingDemo:
         if self.args.method in ['SIFT', 'ORB']:
             kp1, des1 = self.ref_precomp
             kp2, des2 = self.method.descriptor.detectAndCompute(current_frame, None)
+        elif self.args.method in ['XFeat_Ort']:
+            current = self.method.descriptor.detectAndCompute(current_frame)
+            kpts1, descs1, sc1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors'], self.ref_precomp['sc']
+            kpts2, descs2 = current['keypoints'], current['descriptors']
+            matches, batch_indexes = self.method.matcher.match(kpts1, descs1, sc1, kpts2, descs2)
+            points1, points2 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]
         else:
             current = self.method.descriptor.detectAndCompute(current_frame)
             kpts1, descs1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors']
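With these patches applied, the ONNX Runtime path can be selected at launch, e.g. `python realtime_demo.py --method XFeat_Ort --cam 0` (assuming xfeat_onnxruntime.py and the exported models are available alongside the demo script).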
I'm slightly busy with graduation this week; I will be back to work on this by the weekend.
No problem, thanks for letting me know. Congratulations on your graduation!
Here is sample Python code for running xfeat_matching.onnx with the TensorRT API. This example does not include exception handling, input/output validation, etc., and is for reference only.
Update 2024-06-24:
- Updated xfeat_dualscale.onnx: added support for fewer than 4800 keypoints to accommodate low-resolution images.
- Updated matching.onnx: added support for different numbers of features in the two sets (a small sketch follows this list).
- Added TensorRT demo code for xfeat_dualscale.onnx and matching.onnx.
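For illustration only, a small sketch (dummy data, hypothetical shapes, using the ONNX Runtime wrapper from the earlier example) of what matching two feature sets of different sizes with the updated matching.onnx could look like:

import numpy as np
from xfeat_onnxruntime import XFeat  # hypothetical module name from the earlier example

xfeat = XFeat('./xfeat_dualscale.onnx', './matching.onnx')
# Two feature sets with different keypoint counts (zeros used only to show the shapes).
n0, n1 = 3000, 1200
mkpts0 = np.zeros((1, n0, 2), dtype=np.float32)
feats0 = np.zeros((1, n0, 64), dtype=np.float32)
sc0 = np.zeros((1, n0), dtype=np.float32)
mkpts1 = np.zeros((1, n1, 2), dtype=np.float32)
feats1 = np.zeros((1, n1, 64), dtype=np.float32)
matches, batch_indexes = xfeat.match(mkpts0, feats0, sc0, mkpts1, feats1)

The TensorRT demo code for xfeat_matching.onnx mentioned above then follows.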
import os
import sys
import time
from typing import Optional, List
from functools import reduce
import numpy as np
import cv2
import tensorrt as trt
import cupy as cp # pip install cupy-cuda12x
from cuda import cuda, cudart # pip install cuda-python
from packaging.version import Version
if Version(trt.__version__) >= Version('9.0.0'):
    # This is a simple ASCII-art progress monitor comparable to the C++ version in sample_progress_monitor.
    class SimpleProgressMonitor(trt.IProgressMonitor):
        def __init__(self):
            trt.IProgressMonitor.__init__(self)
            self._active_phases = {}
            self._step_result = True

        def phase_start(self, phase_name, parent_phase, num_steps):
            try:
                if parent_phase is not None:
                    nbIndents = 1 + self._active_phases[parent_phase]['nbIndents']
                else:
                    nbIndents = 0
                self._active_phases[phase_name] = {'title': phase_name, 'steps': 0, 'num_steps': num_steps, 'nbIndents': nbIndents}
                self._redraw()
            except KeyboardInterrupt:
                # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete.
                self._step_result = False

        def phase_finish(self, phase_name):
            try:
                del self._active_phases[phase_name]
                self._redraw(blank_lines=1)  # Clear the removed phase.
            except KeyboardInterrupt:
                self._step_result = False

        def step_complete(self, phase_name, step):
            try:
                self._active_phases[phase_name]['steps'] = step
                self._redraw()
                return self._step_result
            except KeyboardInterrupt:
                # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
                return False

        def _redraw(self, *, blank_lines=0):
            # The Python curses module is not widely available on Windows platforms.
            # Instead, this function uses raw terminal escape sequences. See the sample documentation for references.
            def clear_line():
                print('\x1B[2K', end='')

            def move_to_start_of_line():
                print('\x1B[0G', end='')

            def move_cursor_up(lines):
                print('\x1B[{}A'.format(lines), end='')

            def progress_bar(steps, num_steps):
                INNER_WIDTH = 10
                completed_bar_chars = int(INNER_WIDTH * steps / float(num_steps))
                return '[{}{}]'.format(
                    '=' * completed_bar_chars,
                    '-' * (INNER_WIDTH - completed_bar_chars))

            # Set max_cols to a default of 200 if not run in interactive mode.
            max_cols = os.get_terminal_size().columns if sys.stdout.isatty() else 200

            move_to_start_of_line()
            for phase in self._active_phases.values():
                phase_prefix = '{indent}{bar} {title}'.format(
                    indent=' ' * phase['nbIndents'],
                    bar=progress_bar(phase['steps'], phase['num_steps']),
                    title=phase['title'])
                phase_suffix = '{steps}/{num_steps}'.format(**phase)
                allowable_prefix_chars = max_cols - len(phase_suffix) - 2
                if allowable_prefix_chars < len(phase_prefix):
                    phase_prefix = phase_prefix[0:allowable_prefix_chars - 3] + '...'
                clear_line()
                print(phase_prefix, phase_suffix)
            for line in range(blank_lines):
                clear_line()
                print()
            move_cursor_up(len(self._active_phases) + blank_lines)
            sys.stdout.flush()
FP16_ENABLE = True
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class OutputAllocator(trt.IOutputAllocator):
    def __init__(self):
        trt.IOutputAllocator.__init__(self)
        self.buffers = {}
        self.shapes = {}

    def reallocate_output(self, tensor_name, memory, size, alignment):
        # Allocate a byte buffer on the GPU with CuPy and hand its pointer back to TensorRT.
        output_dtype = cp.dtype(cp.byte)
        output = cp.empty(size, output_dtype)
        ptr = output.data.ptr
        self.buffers[tensor_name] = output
        return ptr

    def notify_shape(self, tensor_name, shape):
        self.shapes[tensor_name] = tuple(shape)
def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    TRT_LOGGER = trt.Logger()

    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with."""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            0 if Version(trt.__version__) >= Version('9.0.0') else EXPLICIT_BATCH
        ) as network, builder.create_builder_config() as config, trt.OnnxParser(
            network, TRT_LOGGER
        ) as parser, trt.Runtime(
            TRT_LOGGER
        ) as runtime:
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print("ONNX file {} not found.".format(onnx_file_path))
                return None
            print("Loading ONNX file from path {}...".format(onnx_file_path))
            with open(onnx_file_path, "rb") as model:
                print("Beginning ONNX file parsing")
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse the ONNX file.")
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            print("Completed parsing of ONNX file")

            if Version(trt.__version__) >= Version('9.0.0'):
                config.progress_monitor = SimpleProgressMonitor()
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB
            # config.max_workspace_size = 1 << 30  # 1GB, tensorrt < 8.4
            # builder.max_batch_size = 1

            for i in range(network.num_inputs):
                input_node = network.get_input(i)
                print('input ', i, input_node.name, input_node.shape)
            assert network.num_inputs == 2, "For xfeat_matching.onnx, only two input nodes are supported."

            # Dynamic-shape optimization profile: (min, opt, max) shapes for both image inputs.
            profile = builder.create_optimization_profile()
            profile.set_shape(network.get_input(0).name, (1, 3, 64, 64), (8, 3, 640, 640), (32, 3, 1280, 1280))
            profile.set_shape(network.get_input(1).name, (1, 3, 64, 64), (8, 3, 640, 640), (32, 3, 1280, 1280))
            config.add_optimization_profile(profile)

            if builder.platform_has_fast_fp16 and FP16_ENABLE:
                print("FP16 mode enabled")
                config.set_flag(trt.BuilderFlag.FP16)

            print("Building an engine from file {}; this may take a while...".format(onnx_file_path))
            plan = builder.build_serialized_network(network, config)
            engine = runtime.deserialize_cuda_engine(plan)
            print("Completed creating Engine")
            with open(engine_file_path, "wb") as f:
                f.write(plan)
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        try:
            with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                engine = runtime.deserialize_cuda_engine(f.read())
                if engine is None:
                    print("Deserialization of the engine from {} failed. Falling back to building the engine".format(engine_file_path))
                    return build_engine()
                return engine
        except Exception as e:
            print(e)
            print("Deserialization of the engine from {} failed. Falling back to building the engine".format(engine_file_path))
            return build_engine()
    else:
        return build_engine()
# This function is generalized for multiple inputs/outputs of full-dimension (dynamic shape) networks.
# inputs_dict maps input tensor names to device (CuPy) arrays; outputs are returned as CuPy array views
# over the buffers owned by the OutputAllocator.
def do_inference(context, engine, inputs_dict, output_allocator, stream):
    num_io = engine.num_io_tensors
    outputs = []
    # Bind input device pointers and shapes.
    for i in range(num_io):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_tensor_address(name, inputs_dict[name].data.ptr)
            context.set_input_shape(name, inputs_dict[name].shape)
    context.execute_async_v3(stream_handle=stream)
    cuda_call(cudart.cudaStreamSynchronize(stream))
    # Collect outputs using the shapes reported through the output allocator.
    for i in range(num_io):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            output_shape = output_allocator.shapes[name]
            output_dtype = cp.dtype(trt.nptype(engine.get_tensor_dtype(name)))
            output_size = output_dtype.itemsize * reduce((lambda x, y: x * y), output_shape, 1)
            outputs.append(output_allocator.buffers[name][:output_size].view(output_dtype).reshape(output_shape))
    return outputs
class XFEAT:
    def __init__(self, modelPath, gpu_id=0) -> None:
        self.gpu_id = gpu_id
        # Select the GPU before creating the engine, context, and stream.
        cudart.cudaSetDevice(self.gpu_id)
        engine_file_path = os.path.splitext(modelPath)[0] + ".trt"
        self.engine = get_engine(modelPath, engine_file_path)
        self.context = self.engine.create_execution_context()
        self.stream = cuda_call(cudart.cudaStreamCreate())
        self.output_allocator = OutputAllocator()
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                self.context.set_output_allocator(name, self.output_allocator)
        tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        for tensor_name in tensor_names:
            print(tensor_name, self.engine.get_tensor_shape(tensor_name))

    def inference_preprocessed(self, input_array_0, input_array_1):
        cudart.cudaSetDevice(self.gpu_id)
        inputs_dict = {
            "images0": input_array_0,
            "images1": input_array_1,
        }
        trt_outputs = do_inference(self.context, self.engine, inputs_dict=inputs_dict, output_allocator=self.output_allocator, stream=self.stream)
        matches = cp.asnumpy(trt_outputs[0])
        batch_indexes = cp.asnumpy(trt_outputs[1])
        return matches, batch_indexes

    def __call__(self, input_array_0, input_array_1):
        input_array_device_0 = cp.asarray(input_array_0, dtype=cp.float32)
        input_array_device_1 = cp.asarray(input_array_1, dtype=cp.float32)
        return self.inference_preprocessed(input_array_device_0, input_array_device_1)

    def __del__(self) -> None:
        if hasattr(self, 'stream'):
            if cudart is not None:
                cuda_call(cudart.cudaStreamDestroy(self.stream))
def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
    # Calculate the Homography matrix
    H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
    mask = mask.flatten()

    # Get corners of the first image (image1)
    h, w = img1.shape[:2]
    corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)

    # Warp corners to the second image (image2) space
    warped_corners = cv2.perspectiveTransform(corners_img1, H)

    # Draw the warped corners in image2
    img2_with_corners = img2.copy()
    for i in range(len(warped_corners)):
        start_point = tuple(warped_corners[i-1][0].astype(int))
        end_point = tuple(warped_corners[i][0].astype(int))
        cv2.line(img2_with_corners, start_point, end_point, (0, 255, 0), 4)  # Using solid green for corners

    # Prepare keypoints and matches for drawMatches function
    keypoints1 = [cv2.KeyPoint(p[0], p[1], 5) for p in ref_points]
    keypoints2 = [cv2.KeyPoint(p[0], p[1], 5) for p in dst_points]
    matches = [cv2.DMatch(i, i, 0) for i in range(len(mask)) if mask[i]]

    # Draw inlier matches
    img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
                                  matchColor=(0, 255, 0), flags=2)
    return img_matches
if __name__ == "__main__":
    onnx_file_path = './models/xfeat_matching.onnx'
    input_image_0_path = './images/ref.png'
    input_image_1_path = './images/tgt.png'

    xfeat = XFEAT(onnx_file_path, 0)

    image_0 = cv2.imread(input_image_0_path, cv2.IMREAD_COLOR)
    image_1 = cv2.imread(input_image_1_path, cv2.IMREAD_COLOR)

    batch_size = 8  # Pseudo-batch the input images
    input_array_0 = np.expand_dims(image_0.transpose(2, 0, 1), axis=0).repeat(batch_size, axis=0)
    input_array_1 = np.expand_dims(image_1.transpose(2, 0, 1), axis=0).repeat(batch_size, axis=0)

    matches, batch_indexes = xfeat(input_array_0, input_array_1)
    mkpts_0, mkpts_1 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]

    img_matches = warp_corners_and_draw_matches(mkpts_0, mkpts_1, image_0, image_1)
    cv2.imshow("Matches", img_matches)
    cv2.waitKey(0)

    # Simple throughput benchmark over the pseudo-batch.
    loop = 100
    start = time.time()
    for i in range(loop):
        matches, batch_indexes = xfeat(input_array_0, input_array_1)
    end = time.time()
    print("Time: ", (end - start) / loop / batch_size)
    print("FPS: ", batch_size * loop / (end - start))
Hi @acai66 @IamShubhamGupto ,
Thank you very much for providing the ONNX examples in both C++ and Python. They look amazing and incredibly useful!
Are you still waiting for @IamShubhamGupto to review the merge? I am not sure if you both merged the work you did last month, so I am just checking in to see if I should start reviewing the PR.
Once again, thank you @acai66 and @IamShubhamGupto for the ONNX examples. I appreciate your effort to make XFeat deployment much better!
Thank you for checking in. I am still waiting for @IamShubhamGupto's review. Unfortunately, I haven't received any response from @IamShubhamGupto regarding last month's work, which might be due to the busy graduation season.
Hey @acai66 @guipotje, sorry to keep both of you waiting. I did just graduate and was busy with a few conferences and competitions. As for the development on this branch, I believe you should go ahead and merge it. I will be back to contributing to this project sometime later.
@acai66 Hi, thank you for your contribution. I found that only detectAndComputeDense is supported. Could you also make it support detectAndCompute?
commit: https://github.com/verlab/accelerated_features/pull/5/commits/f4a55c14fddbf61ea1f6de83c8a30baacecdc88b
@acai66 Thank you very much!
@acai66 Have you run into a problem in detectAndCompute where the "# Select top-k features" step exceeds the maximum number of keypoints when using the TRT Execution Provider?
Hey @acai66,
thanks for all your work on the ONNX export of XFeat, it's been very handy. @guipotje recently added the Lighterglue add-on matcher, so I spent some time making that available for ONNX export on top of your changes; see the branch here:
https://github.com/stschake/accelerated_features/tree/feature/lighterglue-onnx
The upstream code uses kornia, which isn't suitable for ONNX export, so I started with the LightGlue-ONNX implementation and modified it slightly to add things like keypoint normalization directly in the model, in the xfeat tradition.
Thanks, guys! Looking forward to the ONNX version.