Add support for exporting to ONNX.
- Support for exporting the xfeat and xfeat+matching models.
- Support for exporting with dynamic shapes.
- Support for exporting with a specified opset version (see the export sketch after the model table below).
- Add an onnxruntime inference demo.
| Model | Inputs | Outputs | Note |
|---|---|---|---|
| xfeat.onnx | images | feats, keypoints, heatmaps | Extract image keypoints and features. |
| xfeat_dualscale.onnx | images | mkpts, feats, sc | Extract dualscale image keypoints and features. |
| matching.onnx | mkpts0, feats0, sc0, mkpts1, feats1 | matches, batch_indexes | Match dualscale keypoints and features. |
| xfeat_matching.onnx | images0, images1 | matches, batch_indexes | End-to-end feature extraction from two sets of images plus feature matching. |
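For reference, below is a minimal sketch of what an export with dynamic shapes and an explicit opset version could look like. It is illustrative only: the `XFeatModel` wrapper and the output names are assumptions based on the table above, and the PR's actual export script may wrap the model differently or handle `top_k` separately.

```python
# Minimal export sketch (assumptions noted above): dynamic batch/height/width
# axes plus an explicit opset version passed to torch.onnx.export.
import torch
from modules.model import XFeatModel  # assumed wrapper; adjust to the PR's export script

net = XFeatModel().eval()
dummy = torch.randn(1, 3, 640, 640)
torch.onnx.export(
    net, dummy, "xfeat.onnx",
    input_names=["images"],
    output_names=["feats", "keypoints", "heatmaps"],
    dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"}},
    opset_version=17,
)
```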
Examples

- FPS is limited by this line, especially when a large number of keypoints are matched: https://github.com/verlab/accelerated_features/blob/b2f6b99b85db0672fc5e49df016850370ce8ad57/realtime_demo.py#L239

xfeat_onnxruntime.py:
```python
import numpy as np
import onnxruntime as ort


def create_ort_session(model_path, trt_engine_cache_path='trt_engine_cache', trt_engine_cache_prefix='model'):
    tmp_ort_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    # Print the input/output names and shapes.
    for i in range(len(tmp_ort_session.get_inputs())):
        print(f"Input name: {tmp_ort_session.get_inputs()[i].name}, shape: {tmp_ort_session.get_inputs()[i].shape}")
    for i in range(len(tmp_ort_session.get_outputs())):
        print(f"Output name: {tmp_ort_session.get_outputs()[i].name}, shape: {tmp_ort_session.get_outputs()[i].shape}")

    providers = [
        # The TensorrtExecutionProvider is the fastest.
        ('TensorrtExecutionProvider', {
            'device_id': 0,
            'trt_max_workspace_size': 4 * 1024 * 1024 * 1024,
            'trt_fp16_enable': True,
            'trt_engine_cache_enable': True,
            'trt_engine_cache_path': trt_engine_cache_path,
            'trt_engine_cache_prefix': trt_engine_cache_prefix,
            'trt_dump_subgraphs': False,
            'trt_timing_cache_enable': True,
            'trt_timing_cache_path': trt_engine_cache_path,
            # 'trt_builder_optimization_level': 3,
        }),
        # The CUDAExecutionProvider is slower than PyTorch, possibly due to performance issues
        # with the large matrix multiplication "cossim = torch.bmm(feats1, feats2.permute(0, 2, 1))".
        # Reducing the top_k value when exporting to ONNX can decrease the matrix size.
        ('CUDAExecutionProvider', {
            'device_id': 0,
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,
        }),
        ('CPUExecutionProvider', {}),
    ]
    ort_session = ort.InferenceSession(model_path, providers=providers)
    return ort_session


class XFeat:
    def __init__(self, xfeat_model_path='./xfeat_dualscale.onnx', matcher_model_path='./matching.onnx'):
        self.xfeat_ort_session = create_ort_session(xfeat_model_path, trt_engine_cache_prefix='xfeat_dualscale')
        self.matcher_ort_session = create_ort_session(matcher_model_path, trt_engine_cache_prefix='matching')
        # Warm up both sessions.
        for i in range(5):
            image = np.zeros((640, 640, 3), dtype=np.float32)
            self.detectAndCompute(image)
            mkpts0 = np.zeros((1, 4800, 2), dtype=np.float32)
            feats0 = np.zeros((1, 4800, 64), dtype=np.float32)
            sc0 = np.zeros((1, 4800), dtype=np.float32)
            mkpts1 = np.zeros((1, 4800, 2), dtype=np.float32)
            feats1 = np.zeros((1, 4800, 64), dtype=np.float32)
            self.match(mkpts0, feats0, sc0, mkpts1, feats1)

    def detectAndCompute(self, image_data, mask=None):
        input_array = np.expand_dims(image_data.transpose((2, 0, 1)), axis=0).astype(np.float32)
        inputs = {
            self.xfeat_ort_session.get_inputs()[0].name: input_array,
        }
        mkpts0, feats0, sc = self.xfeat_ort_session.run(None, inputs)
        return {
            "keypoints": mkpts0,
            "descriptors": feats0,
            "sc": sc,
        }

    def match(self, mkpts0, feats0, sc0, mkpts1, feats1):
        inputs = {
            self.matcher_ort_session.get_inputs()[0].name: mkpts0,
            self.matcher_ort_session.get_inputs()[1].name: feats0,
            self.matcher_ort_session.get_inputs()[2].name: sc0,
            self.matcher_ort_session.get_inputs()[3].name: mkpts1,
            self.matcher_ort_session.get_inputs()[4].name: feats1,
        }
        matches, batch_indexes = self.matcher_ort_session.run(None, inputs)
        return matches, batch_indexes
```
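A hypothetical usage example for the wrapper above; the image paths are placeholders, and images are passed as HxWx3 float32 arrays to match the warm-up shape used in `__init__`. The column split of `matches` follows the realtime_demo.py patch below.

```python
# Hypothetical usage of the ONNX Runtime XFeat wrapper defined above.
import cv2
import numpy as np
from xfeat_onnxruntime import XFeat

xfeat = XFeat('./xfeat_dualscale.onnx', './matching.onnx')
img0 = cv2.imread('ref.png').astype(np.float32)   # placeholder path
img1 = cv2.imread('tgt.png').astype(np.float32)   # placeholder path

out0 = xfeat.detectAndCompute(img0)
out1 = xfeat.detectAndCompute(img1)
matches, batch_indexes = xfeat.match(
    out0['keypoints'], out0['descriptors'], out0['sc'],
    out1['keypoints'], out1['descriptors'])

# Matched point pairs for the first (and only) batch element.
points0 = matches[batch_indexes == 0][..., :2]
points1 = matches[batch_indexes == 0][..., 2:]
```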
realtime_demo.py patches
```diff
diff --git a/realtime_demo.py b/realtime_demo.py
index 6c867fd..8e21ae2 100644
--- a/realtime_demo.py
+++ b/realtime_demo.py
@@ -7,20 +7,18 @@
 import cv2
 import numpy as np
-import torch
 
 from time import time, sleep
 import argparse, sys, tqdm
 import threading
 
-from modules.xfeat import XFeat
 
 def argparser():
     parser = argparse.ArgumentParser(description="Configurations for the real-time matching demo.")
     parser.add_argument('--width', type=int, default=640, help='Width of the video capture stream.')
     parser.add_argument('--height', type=int, default=480, help='Height of the video capture stream.')
     parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.')
-    parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.')
+    parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat', 'XFeat_Ort'], default='XFeat', help='Local feature detection method to use.')
     parser.add_argument('--cam', type=int, default=0, help='Webcam device number.')
     return parser.parse_args()
 
@@ -52,6 +50,7 @@ class CVWrapper():
     def __init__(self, mtd):
         self.mtd = mtd
     def detectAndCompute(self, x, mask=None):
+        import torch
         return self.mtd.detectAndCompute(torch.tensor(x).permute(2,0,1).float()[None])[0]
 
 class Method:
@@ -65,7 +64,12 @@ def init_method(method, max_kpts):
     elif method == "SIFT":
         return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True))
     elif method == "XFeat":
+        from modules.xfeat import XFeat
         return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat())
+    elif method == "XFeat_Ort":
+        from xfeat_onnxruntime import XFeat
+        xfeat = XFeat()
+        return Method(descriptor=xfeat, matcher=xfeat)
     else:
         raise RuntimeError("Invalid Method.")
 
@@ -200,6 +204,12 @@ class MatchingDemo:
         if self.args.method in ['SIFT', 'ORB']:
             kp1, des1 = self.ref_precomp
             kp2, des2 = self.method.descriptor.detectAndCompute(current_frame, None)
+        elif self.args.method in ['XFeat_Ort']:
+            current = self.method.descriptor.detectAndCompute(current_frame)
+            kpts1, descs1, sc1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors'], self.ref_precomp['sc']
+            kpts2, descs2 = current['keypoints'], current['descriptors']
+            matches, batch_indexes = self.method.matcher.match(kpts1, descs1, sc1, kpts2, descs2)
+            points1, points2 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]
         else:
             current = self.method.descriptor.detectAndCompute(current_frame)
             kpts1, descs1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors']
```
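With this patch applied, the ONNX Runtime backend can be selected from the command line, e.g. `python realtime_demo.py --method XFeat_Ort --cam 0`, while the default `--method XFeat` continues to use the PyTorch implementation.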
I'm slightly busy with graduation this week; I will be back to work on this by the weekend.
No problem, thanks for letting me know. Congratulations on your graduation!
Here is sample Python code for running xfeat_matching.onnx with the TensorRT API. This example does not include exception handling, input/output validation, etc., and is for reference only.
Update 2024-06-24:
- Update xfeat_dualscale.onnx: add support for fewer than 4800 num_keypoints to accommodate low-resolution images
- Update matching.onnx: add support for different numbers of keypoints/features in the two sets
- Add TensorRT demo code for xfeat_dualscale.onnx and matching.onnx
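As a rough illustration of the second point, with the updated matching.onnx the two feature sets no longer need to contain the same number of keypoints. The sketch below reuses the ONNX Runtime wrapper from xfeat_onnxruntime.py above; the array sizes are arbitrary placeholders.

```python
# Illustrative only: matching two feature sets of different sizes with the
# updated matching.onnx (sizes below are arbitrary, zeros stand in for real features).
import numpy as np
from xfeat_onnxruntime import XFeat

xfeat = XFeat('./xfeat_dualscale.onnx', './matching.onnx')
mkpts0 = np.zeros((1, 1200, 2), dtype=np.float32)
feats0 = np.zeros((1, 1200, 64), dtype=np.float32)
sc0 = np.zeros((1, 1200), dtype=np.float32)
mkpts1 = np.zeros((1, 3000, 2), dtype=np.float32)
feats1 = np.zeros((1, 3000, 64), dtype=np.float32)
matches, batch_indexes = xfeat.match(mkpts0, feats0, sc0, mkpts1, feats1)
```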
```python
import os
import sys
import time
from typing import Optional, List
from functools import reduce

import numpy as np
import cv2
import tensorrt as trt
import cupy as cp  # pip install cupy-cuda12x
from cuda import cuda, cudart  # pip install cuda-python
from packaging.version import Version

if Version(trt.__version__) >= Version('9.0.0'):
    # This is a simple ASCII-art progress monitor comparable to the C++ version in sample_progress_monitor.
    class SimpleProgressMonitor(trt.IProgressMonitor):
        def __init__(self):
            trt.IProgressMonitor.__init__(self)
            self._active_phases = {}
            self._step_result = True

        def phase_start(self, phase_name, parent_phase, num_steps):
            try:
                if parent_phase is not None:
                    nbIndents = 1 + self._active_phases[parent_phase]['nbIndents']
                else:
                    nbIndents = 0
                self._active_phases[phase_name] = {'title': phase_name, 'steps': 0, 'num_steps': num_steps, 'nbIndents': nbIndents}
                self._redraw()
            except KeyboardInterrupt:
                # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete.
                self._step_result = False

        def phase_finish(self, phase_name):
            try:
                del self._active_phases[phase_name]
                self._redraw(blank_lines=1)  # Clear the removed phase.
            except KeyboardInterrupt:
                self._step_result = False

        def step_complete(self, phase_name, step):
            try:
                self._active_phases[phase_name]['steps'] = step
                self._redraw()
                return self._step_result
            except KeyboardInterrupt:
                # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
                return False

        def _redraw(self, *, blank_lines=0):
            # The Python curses module is not widely available on Windows platforms.
            # Instead, this function uses raw terminal escape sequences. See the sample documentation for references.
            def clear_line():
                print('\x1B[2K', end='')

            def move_to_start_of_line():
                print('\x1B[0G', end='')

            def move_cursor_up(lines):
                print('\x1B[{}A'.format(lines), end='')

            def progress_bar(steps, num_steps):
                INNER_WIDTH = 10
                completed_bar_chars = int(INNER_WIDTH * steps / float(num_steps))
                return '[{}{}]'.format(
                    '=' * completed_bar_chars,
                    '-' * (INNER_WIDTH - completed_bar_chars))

            # Set max_cols to a default of 200 if not run in interactive mode.
            max_cols = os.get_terminal_size().columns if sys.stdout.isatty() else 200

            move_to_start_of_line()
            for phase in self._active_phases.values():
                phase_prefix = '{indent}{bar} {title}'.format(
                    indent=' ' * phase['nbIndents'],
                    bar=progress_bar(phase['steps'], phase['num_steps']),
                    title=phase['title'])
                phase_suffix = '{steps}/{num_steps}'.format(**phase)
                allowable_prefix_chars = max_cols - len(phase_suffix) - 2
                if allowable_prefix_chars < len(phase_prefix):
                    phase_prefix = phase_prefix[0:allowable_prefix_chars - 3] + '...'
                clear_line()
                print(phase_prefix, phase_suffix)
            for line in range(blank_lines):
                clear_line()
                print()
            move_cursor_up(len(self._active_phases) + blank_lines)
            sys.stdout.flush()


FP16_ENABLE = True
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class OutputAllocator(trt.IOutputAllocator):
    def __init__(self):
        trt.IOutputAllocator.__init__(self)
        self.buffers = {}
        self.shapes = {}

    def reallocate_output(self, tensor_name, memory, size, alignment):
        # Allocate a raw byte buffer on the device; it is reinterpreted after inference.
        output_dtype = cp.dtype(cp.byte)
        output = cp.empty(size, output_dtype)
        ptr = output.data.ptr
        self.buffers[tensor_name] = output
        return ptr

    def notify_shape(self, tensor_name, shape):
        self.shapes[tensor_name] = tuple(shape)


def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    TRT_LOGGER = trt.Logger()

    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with."""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            0 if Version(trt.__version__) >= Version('9.0.0') else EXPLICIT_BATCH
        ) as network, builder.create_builder_config() as config, trt.OnnxParser(
            network, TRT_LOGGER
        ) as parser, trt.Runtime(
            TRT_LOGGER
        ) as runtime:
            # Parse the model file.
            if not os.path.exists(onnx_file_path):
                print("ONNX file {} not found.".format(onnx_file_path))
                return None
            print("Loading ONNX file from path {}...".format(onnx_file_path))
            with open(onnx_file_path, "rb") as model:
                print("Beginning ONNX file parsing")
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse the ONNX file.")
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            print("Completed parsing of ONNX file")

            if Version(trt.__version__) >= Version('9.0.0'):
                config.progress_monitor = SimpleProgressMonitor()
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB
            # config.max_workspace_size = 1 << 30  # 1GB, tensorrt < 8.4
            # builder.max_batch_size = 1

            for i in range(network.num_inputs):
                input_node = network.get_input(i)
                print('input ', i, input_node.name, input_node.shape)
            assert network.num_inputs == 2, "xfeat_matching.onnx is expected to have exactly two input nodes."

            profile = builder.create_optimization_profile()
            profile.set_shape(network.get_input(0).name, (1, 3, 64, 64), (8, 3, 640, 640), (32, 3, 1280, 1280))
            profile.set_shape(network.get_input(1).name, (1, 3, 64, 64), (8, 3, 640, 640), (32, 3, 1280, 1280))
            config.add_optimization_profile(profile)

            if builder.platform_has_fast_fp16 and FP16_ENABLE:
                print("FP16 mode enabled")
                config.set_flag(trt.BuilderFlag.FP16)

            print("Building an engine from file {}; this may take a while...".format(onnx_file_path))
            plan = builder.build_serialized_network(network, config)
            engine = runtime.deserialize_cuda_engine(plan)
            print("Completed creating Engine")
            with open(engine_file_path, "wb") as f:
                f.write(plan)
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building a new one.
        print("Reading engine from file {}".format(engine_file_path))
        try:
            with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                engine = runtime.deserialize_cuda_engine(f.read())
                if engine is None:
                    print("Deserialization of the engine from {} failed. Falling back to building the engine".format(engine_file_path))
                    return build_engine()
                return engine
        except Exception as e:
            print(e)
            print("Deserialization of the engine from {} failed. Falling back to building the engine".format(engine_file_path))
            return build_engine()
    else:
        return build_engine()


# This function is generalized for multiple inputs/outputs of full-dimension networks.
# Inputs are CuPy device arrays keyed by tensor name; outputs are gathered from the output allocator.
def do_inference(context, engine, inputs_dict, output_allocator, stream):
    num_io = engine.num_io_tensors
    outputs = []
    for i in range(num_io):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_tensor_address(name, inputs_dict[name].data.ptr)
            context.set_input_shape(name, inputs_dict[name].shape)
    context.execute_async_v3(stream_handle=stream)
    cuda_call(cudart.cudaStreamSynchronize(stream))
    for i in range(num_io):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            output_shape = output_allocator.shapes[name]
            output_dtype = cp.dtype(trt.nptype(engine.get_tensor_dtype(name)))
            output_size = output_dtype.itemsize * reduce((lambda x, y: x * y), output_shape, 1)
            outputs.append(output_allocator.buffers[name][:output_size].view(output_dtype).reshape(output_shape))
    return outputs


class XFEAT:
    def __init__(self, modelPath, gpu_id=0) -> None:
        self.gpu_id = gpu_id
        # Select the GPU to use.
        cudart.cudaSetDevice(self.gpu_id)
        engine_file_path = os.path.splitext(modelPath)[0] + ".trt"
        self.engine = get_engine(modelPath, engine_file_path)
        self.context = self.engine.create_execution_context()
        self.stream = cuda_call(cudart.cudaStreamCreate())
        self.output_allocator = OutputAllocator()
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                self.context.set_output_allocator(name, self.output_allocator)
        tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        for tensor_name in tensor_names:
            print(tensor_name, self.engine.get_tensor_shape(tensor_name))

    def inference_preprocessed(self, input_array_0, input_array_1):
        cudart.cudaSetDevice(self.gpu_id)
        inputs_dict = {
            "images0": input_array_0,
            "images1": input_array_1,
        }
        trt_outputs = do_inference(self.context, self.engine, inputs_dict=inputs_dict, output_allocator=self.output_allocator, stream=self.stream)
        matches = cp.asnumpy(trt_outputs[0])
        batch_indexes = cp.asnumpy(trt_outputs[1])
        return matches, batch_indexes

    def __call__(self, input_array_0, input_array_1):
        input_array_device_0 = cp.asarray(input_array_0, dtype=cp.float32)
        input_array_device_1 = cp.asarray(input_array_1, dtype=cp.float32)
        return self.inference_preprocessed(input_array_device_0, input_array_device_1)

    def __del__(self) -> None:
        if hasattr(self, 'stream'):
            if cudart is not None:
                cuda_call(cudart.cudaStreamDestroy(self.stream))


def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
    # Calculate the homography matrix.
    H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
    mask = mask.flatten()

    # Get the corners of the first image (image1).
    h, w = img1.shape[:2]
    corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)

    # Warp the corners into the second image (image2) space.
    warped_corners = cv2.perspectiveTransform(corners_img1, H)

    # Draw the warped corners in image2.
    img2_with_corners = img2.copy()
    for i in range(len(warped_corners)):
        start_point = tuple(warped_corners[i-1][0].astype(int))
        end_point = tuple(warped_corners[i][0].astype(int))
        cv2.line(img2_with_corners, start_point, end_point, (0, 255, 0), 4)  # Solid green for corners.

    # Prepare keypoints and matches for the drawMatches function.
    keypoints1 = [cv2.KeyPoint(p[0], p[1], 5) for p in ref_points]
    keypoints2 = [cv2.KeyPoint(p[0], p[1], 5) for p in dst_points]
    matches = [cv2.DMatch(i, i, 0) for i in range(len(mask)) if mask[i]]

    # Draw inlier matches.
    img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
                                  matchColor=(0, 255, 0), flags=2)
    return img_matches


if __name__ == "__main__":
    onnx_file_path = './models/xfeat_matching.onnx'
    input_image_0_path = './images/ref.png'
    input_image_1_path = './images/tgt.png'

    xfeat = XFEAT(onnx_file_path, 0)

    image_0 = cv2.imread(input_image_0_path, cv2.IMREAD_COLOR)
    image_1 = cv2.imread(input_image_1_path, cv2.IMREAD_COLOR)

    batch_size = 8  # Pseudo-batch the input images.
    input_array_0 = np.expand_dims(image_0.transpose(2, 0, 1), axis=0).repeat(batch_size, axis=0)
    input_array_1 = np.expand_dims(image_1.transpose(2, 0, 1), axis=0).repeat(batch_size, axis=0)

    matches, batch_indexes = xfeat(input_array_0, input_array_1)
    mkpts_0, mkpts_1 = matches[batch_indexes == 0][..., :2], matches[batch_indexes == 0][..., 2:]

    img_matches = warp_corners_and_draw_matches(mkpts_0, mkpts_1, image_0, image_1)
    cv2.imshow("Matches", img_matches)
    cv2.waitKey(0)

    loop = 100
    start = time.time()
    for i in range(loop):
        matches, batch_indexes = xfeat(input_array_0, input_array_1)
    end = time.time()
    print("Time: ", (end - start) / loop / batch_size)
    print("FPS: ", batch_size * loop / (end - start))
```
Hi @acai66 @IamShubhamGupto ,
Thank you very much for providing the ONNX examples in both C++ and Python. They look amazing and incredibly useful!
Are you still waiting for @IamShubhamGupto to review the merge? I am not sure if you both merged the work you did last month, so I am just checking in to see if I should start reviewing the PR.
Once again, thank you @acai66 and @IamShubhamGupto for the ONNX examples. I appreciate your effort to make XFeat deployment much better!
Thank you for checking in. I am still waiting for @IamShubhamGupto's review. Unfortunately, I haven't received any response from @IamShubhamGupto regarding last month's work, which might be due to the busy graduation season.
Hey @acai66 @guipotje, sorry to keep both of you waiting. I did just graduate and was busy with a few conferences and competitions. As for the development on this branch, I believe you should go ahead and merge it. I will be back to contributing to this project sometime later.
@acai66 Hi, thank you for your contribution. I found that only `detectAndComputeDense` is supported. Could you also add support for `detectAndCompute`?
commit: https://github.com/verlab/accelerated_features/pull/5/commits/f4a55c14fddbf61ea1f6de83c8a30baacecdc88b
@acai66 Thank you very much!
@acai66 Have you run into a problem in detectAndCompute where the "select top-k features" step exceeds the maximum number of keypoints when using the TensorRT Execution Provider?
Hey @acai66,
thanks for all your work on the ONNX export of XFeat, it's been very handy. @guipotje recently added the LighterGlue addon matcher, so I spent some time making it available for ONNX export on top of your changes; see the branch here:
https://github.com/stschake/accelerated_features/tree/feature/lighterglue-onnx
The upstream code uses kornia, which isn't suitable for ONNX export, so I started with the LightGlue-ONNX implementation and modified it slightly to add things like keypoint normalization directly in the model, in the XFeat tradition.
Thanks guys! Looking forward to the ONNX version.