Face-Restoration-TensorRT
A Python script to generate a TensorRT (TRT) engine from an ONNX model.
This code targets 512x512 inputs. It is simple, but it runs.
import logging
import os.path
import sys
import time
from cuda import cudart
import tensorrt as trt
# Module-level logger. basicConfig below configures the root logger
# (DEBUG level, timestamped format); this logger propagates to it.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s",
)
def build_engine(onnx_file_path, trt_model_path, max_workspace_size=8 << 30, fp16_mode=False,
                 device_id=6, input_name="input", input_shape=(1, 3, 512, 512)):
    """Build a TensorRT engine from an ONNX model, save it, and deserialize it.

    The network is created with the explicit-batch flag and a single
    optimization profile whose min/opt/max shapes are all ``input_shape``.
    The resulting engine therefore only accepts that exact input shape.

    Args:
        onnx_file_path: Path of the ONNX model to parse.
        trt_model_path: Path the serialized engine is written to. An existing
            file is overwritten (opened with mode 'wb').
        max_workspace_size: Workspace memory-pool limit in bytes
            (default ``8 << 30``, i.e. 8 GiB).
        fp16_mode: Request FP16 kernels. Only honoured when the builder
            reports fast FP16 support.
        device_id: CUDA device to build on (default 6, the device the
            original script hard-coded).
        input_name: Name of the network input the optimization profile is
            applied to. It must match the ONNX graph's input name.
        input_shape: Shape used as min, opt and max of the profile.

    Returns:
        The deserialized engine, or ``None`` on failure. Failure means
        selecting the CUDA device, parsing the ONNX file, or building the
        serialized engine failed.
    """
    trt_logger = trt.Logger(trt.Logger.VERBOSE)

    # Select the GPU before creating any TensorRT object. cuda-python returns
    # a tuple whose first element is the error code. If it is not checked, a
    # failed selection goes unnoticed and the build runs on whatever device
    # is current.
    err = cudart.cudaSetDevice(device_id)[0]
    if err != cudart.cudaError_t.cudaSuccess:
        logger.error(f"Failed to select CUDA device {device_id}: {err}")
        return None

    builder = trt.Builder(trt_logger)
    # Explicit-batch network: the batch dimension is part of the input shape.
    network = builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)
    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            logger.error("Failed to parse the ONNX file.")
            for index in range(parser.num_errors):
                logger.error(parser.get_error(index))
            return None

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)
    logger.debug(f"platform_has_fast_fp16 is {builder.platform_has_fast_fp16}")
    if fp16_mode:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        else:
            logger.warning("FP16 requested but platform has no fast FP16 support; FP16 flag not set.")

    # Adjust input_name / input_shape to match your model's input. Identical
    # min/opt/max shapes specialise the engine to exactly this shape.
    profile = builder.create_optimization_profile()
    shape = tuple(input_shape)
    profile.set_shape(input_name, shape, shape, shape)
    config.add_optimization_profile(profile)

    logger.info("|setup| build_serialized_network")
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        logger.error("Failed to build serialized engine.")
        return None
    with open(trt_model_path, 'wb') as f:
        f.write(serialized_engine)

    logger.info("|setup| create runtime")
    runtime = trt.Runtime(trt_logger)
    logger.info("|setup| deserialize_cuda_engine")
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    return engine
def run_build_trt(onnx_file_path, trt_model_path):
    """Check the paths, build the TensorRT engine, and log the elapsed time.

    Calls ``sys.exit(-1)`` in three cases:
    - the ONNX file does not exist;
    - ``trt_model_path`` already exists (an existing engine is never
      overwritten);
    - ``build_engine`` reports failure by returning ``None``.
    """
    logger.debug(f"onnx_file_path => {onnx_file_path}, trt_model_path=>{trt_model_path}")
    if not os.path.exists(onnx_file_path):
        logger.error(f"{onnx_file_path} is not exists.")
        sys.exit(-1)
    if os.path.exists(trt_model_path):
        logger.error(f"{trt_model_path} is exists !!!.")
        sys.exit(-1)

    # perf_counter is monotonic, so it is the right clock for measuring a duration.
    start = time.perf_counter()
    engine = build_engine(onnx_file_path, trt_model_path)
    elapsed = time.perf_counter() - start

    # build_engine signals every failure by returning None; do not report success then.
    if engine is None:
        logger.error(f"failed to run build_engine for {onnx_file_path}, cost: {elapsed}.")
        sys.exit(-1)
    logger.info(f"success to run build_engine, result is {trt_model_path}, cost: {elapsed}.")
if __name__ == '__main__':
    # Usage: python <script> [ONNX_PATH [TRT_PATH]]
    # Any argument that is omitted falls back to the original hard-coded path.
    onnx_file_path = sys.argv[1] if len(sys.argv) > 1 else "model-512.onnx"
    trt_model_path = sys.argv[2] if len(sys.argv) > 2 else "gan_512_v1_0625.trt"
    run_build_trt(onnx_file_path, trt_model_path)