Face-Restoration-TensorRT

A Python script to generate a TensorRT engine

Status: Open. goingHan opened this issue 9 months ago (0 comments).

My code targets 512x512 inputs. It is simple, but it runs.

import logging
import os.path
import sys
import time
from cuda import cudart
import tensorrt as trt

# Configure the root logger once for the whole script: DEBUG and above, with
# file name and line number in every record.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s",
)
# The functions below log through the (root) logger configured above.
logger = logging.getLogger()


def build_engine(
    onnx_file_path,
    trt_model_path,
    max_workspace_size=8 << 30,
    fp16_mode=False,
    *,
    device_id=6,
    input_name="input",
    min_shape=(1, 3, 512, 512),
    opt_shape=(1, 3, 512, 512),
    max_shape=(1, 3, 512, 512),
):
    """Build a TensorRT engine from an ONNX model and save it to disk.

    The ONNX file is parsed into an explicit-batch network and compiled with a
    single optimization profile for ``input_name``. The engine is serialized to
    ``trt_model_path`` and then deserialized, so the caller gets an engine
    object back.

    Args:
        onnx_file_path: Path of the ONNX model to convert.
        trt_model_path: Destination path for the serialized engine. The file
            is opened in ``'wb'`` mode, so an existing file is overwritten.
        max_workspace_size: Limit in bytes for the WORKSPACE memory pool
            (default ``8 << 30``, i.e. 8 GiB).
        fp16_mode: Request FP16 kernels. The flag is only set when the
            builder reports fast FP16 support; otherwise a warning is logged.
        device_id: CUDA device index passed to ``cudaSetDevice``. ``None``
            leaves the current device untouched (e.g. when GPU selection is
            done via ``CUDA_VISIBLE_DEVICES`` instead).
        input_name: Network input the optimization profile applies to. It
            must match the input name in the ONNX graph.
        min_shape, opt_shape, max_shape: Profile shapes for ``input_name``.
            The defaults fix the input at ``(1, 3, 512, 512)``.

    Returns:
        Whatever ``Runtime.deserialize_cuda_engine`` returns for the built
        engine, or ``None`` if device selection, ONNX parsing or engine
        building failed. Every failure is logged.
    """
    trt_logger = trt.Logger(trt.Logger.VERBOSE)

    if device_id is not None:
        # cuda-python runtime calls return a tuple whose first element is the
        # cudaError_t status; check it instead of silently building on the
        # wrong GPU.
        err = cudart.cudaSetDevice(device_id)[0]
        if err != cudart.cudaError_t.cudaSuccess:
            logger.error(f"cudaSetDevice({device_id}) failed: {err}")
            return None

    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, trt_logger)

    with open(onnx_file_path, "rb") as model:
        onnx_bytes = model.read()
    if not parser.parse(onnx_bytes):
        logger.error(f"Failed to parse the ONNX file {onnx_file_path}.")
        for index in range(parser.num_errors):
            # Pass the error as an argument so any '%' in its text is not
            # treated as a format directive.
            logger.error("%s", parser.get_error(index))
        return None

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size)

    logger.debug(f"platform_has_fast_fp16 is {builder.platform_has_fast_fp16}")
    if fp16_mode:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        else:
            logger.warning(
                "FP16 requested but platform_has_fast_fp16 is False; "
                "building without the FP16 flag."
            )

    # A single profile; set min == opt == max for a fixed input size.
    profile = builder.create_optimization_profile()
    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    logger.info("|setup| build_serialized_network")
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        logger.error("Failed to build serialized engine.")
        return None

    with open(trt_model_path, "wb") as f:
        f.write(serialized_engine)

    logger.info("|setup| create runtime")
    runtime = trt.Runtime(trt_logger)
    logger.info("|setup| deserialize_cuda_engine")
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    if engine is None:
        logger.error(f"Failed to deserialize the engine written to {trt_model_path}.")
    return engine


def run_build_trt(onnx_file_path, trt_model_path):
    """Validate the paths, build the TensorRT engine and report the outcome.

    Calls ``sys.exit(-1)`` when:
      * the ONNX file does not exist,
      * ``trt_model_path`` already exists (an existing engine is never
        overwritten), or
      * ``build_engine`` signals failure by returning ``None``.

    On success, logs the output path and the build time in seconds.
    """
    logger.debug(f"onnx_file_path => {onnx_file_path}, trt_model_path=>{trt_model_path}")

    if not os.path.exists(onnx_file_path):
        logger.error(f"{onnx_file_path} is not exists.")
        sys.exit(-1)
    if os.path.exists(trt_model_path):
        logger.error(f"{trt_model_path} is  exists !!!.")
        sys.exit(-1)

    # perf_counter is monotonic, so wall-clock adjustments can't skew the
    # reported duration.
    start = time.perf_counter()
    engine = build_engine(onnx_file_path, trt_model_path)
    elapsed = time.perf_counter() - start

    # build_engine returns None on every failure path; don't report success
    # (or exit with status 0) in that case.
    if engine is None:
        logger.error(f"failed to run build_engine for {onnx_file_path}, cost: {elapsed}.")
        sys.exit(-1)
    logger.info(f"success to run build_engine, result is {trt_model_path}, cost: {elapsed}.")


if __name__ == '__main__':
    import argparse

    # Both paths are optional positional arguments. With no arguments the
    # script uses the original 512x512 model/engine file names.
    arg_parser = argparse.ArgumentParser(
        description="Convert an ONNX model into a serialized TensorRT engine."
    )
    arg_parser.add_argument(
        "onnx_file_path",
        nargs="?",
        default="model-512.onnx",
        help="input ONNX model (default: %(default)s)",
    )
    arg_parser.add_argument(
        "trt_model_path",
        nargs="?",
        default="gan_512_v1_0625.trt",
        help="output engine path; must not already exist (default: %(default)s)",
    )
    args = arg_parser.parse_args()
    run_build_trt(args.onnx_file_path, args.trt_model_path)

Posted by goingHan, Jun 26 '25 09:06