
Add TensorRT support for BF16

Open andompesta opened this issue 3 months ago • 0 comments

This pull request adds support for TensorRT integration.

Summary of Changes
  1. CLI Support for TensorRT

    • Added a new variable in src/flux/cli.py to enable TensorRT inference.
    • The environment variable TRT_ENGINE_DIR specifies the directory for storing TensorRT engines.
    • The environment variable ONNX_DIR specifies the directory for ONNX model exports.
    • Currently supports BF16 (FP16 and FP8 coming soon).
    • Support for model offloading can be added later.
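The two directories above can be resolved from the environment roughly like this (a minimal sketch; only the variable names TRT_ENGINE_DIR and ONNX_DIR come from this PR — the helper name and default paths are illustrative assumptions):

```python
import os

# Hypothetical helper showing how the CLI could pick up the two
# directories. TRT_ENGINE_DIR and ONNX_DIR are the variable names
# introduced in this PR; the fallback defaults are illustrative.
def resolve_trt_dirs() -> tuple[str, str]:
    engine_dir = os.environ.get("TRT_ENGINE_DIR", "trt_engines")
    onnx_dir = os.environ.get("ONNX_DIR", "onnx")
    return engine_dir, onnx_dir
```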
  2. Modifications for ONNX Export

  3. TensorRT Exporter

    • Added the flux/trt/exporter package, containing code to export PyTorch models to ONNX and build TensorRT engines.
  4. TensorRT Engine Execution

    • Added the flux/trt/engine package, which is responsible for executing inference with the TensorRT engines.
  5. TensorRT Mixin Classes

    • Added the flux/trt/mixin package with mixin classes to share parameters between the model building and inference phases.
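The mixin idea can be sketched like this (class, attribute, and property names are illustrative assumptions, not the actual flux/trt/mixin API):

```python
# Illustrative parameter-sharing mixin: both the engine builder and
# the runtime wrapper inherit the same shape metadata, so the values
# used at build time are guaranteed to match those used at inference.
class ShapeMixin:
    def __init__(self, height: int, width: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.height = height
        self.width = width

    @property
    def latent_shape(self) -> tuple[int, int]:
        # Assumes an 8x-downsampling autoencoder, as in FLUX.
        return self.height // 8, self.width // 8

class EngineBuilder(ShapeMixin):
    pass

class EngineRunner(ShapeMixin):
    pass
```

Because both classes derive the shape from one definition, a build/inference mismatch cannot arise from duplicated constants.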
  6. TensorRT Manager

    • Introduced flux/trt/trt_manager.py as the main TensorRT management class. It handles the conversion of PyTorch models to TensorRT engines and manages the TensorRT context for inference.
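A hedged sketch of what such a manager might look like — all names below are illustrative assumptions; only the existence of flux/trt/trt_manager.py comes from this PR:

```python
from pathlib import Path
from typing import Callable

# Illustrative manager: caches built engines on disk keyed by model
# name, so repeated runs skip the expensive export/build step.
class TRTManager:
    def __init__(self, engine_dir: str, onnx_dir: str) -> None:
        self.engine_dir = Path(engine_dir)
        self.onnx_dir = Path(onnx_dir)
        self.engine_dir.mkdir(parents=True, exist_ok=True)
        self.onnx_dir.mkdir(parents=True, exist_ok=True)

    def engine_path(self, model_name: str) -> Path:
        return self.engine_dir / f"{model_name}.plan"

    def load_or_build(self, model_name: str,
                      build_fn: Callable[[Path], None]) -> Path:
        # build_fn stands in for the ONNX-export + TensorRT-build
        # pipeline; it is only invoked when no cached engine exists.
        path = self.engine_path(model_name)
        if not path.exists():
            build_fn(path)
        return path
```

In this sketch the expensive build runs once per model; subsequent calls simply return the cached engine path.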

andompesta · Nov 14 '24 22:11