Add TRT support for BF16
This pull request aims to add support for TensorRT integration.
Summary of Changes
- **CLI Support for TensorRT**
  - Added a new variable in `src/flux/cli.py` to enable TensorRT inference (see the usage sketch after this list).
  - The environment variable `TRT_ENGINE_DIR` specifies the directory for storing TensorRT engines.
  - The environment variable `ONNX_DIR` specifies the directory for ONNX model exports.
  - Currently supports `bf16` (`fp16` and `fp8` coming soon).
  - Support for model offloading can be added.
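A minimal usage sketch follows. Only the two environment variables are named in this description; the `--trt` flag is an assumption about how the new CLI variable might be exposed:

```python
# Hypothetical invocation: the `--trt` flag name is an assumption; the
# TRT_ENGINE_DIR and ONNX_DIR variables are the ones this PR describes.
import os
import subprocess

env = os.environ.copy()
env["TRT_ENGINE_DIR"] = "/workspace/trt_engines"  # where built engines are stored
env["ONNX_DIR"] = "/workspace/onnx"               # where ONNX exports are written

subprocess.run(
    ["python", "-m", "flux", "--name", "flux-schnell", "--trt"],
    env=env,
    check=True,
)
```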
- **Modifications for ONNX Export**
  - Minor changes to `src/flux/math.py` and `src/flux/modules/autoencoder.py` to enable proper ONNX export (a minimal export sketch follows this list).
  - Additional changes are required to address numerical stability issues within the Flux transformer model.
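To picture the export path, here is a minimal sketch using `torch.onnx.export` on the autoencoder's decoder. Shapes, opset, and file names are illustrative; the PR's actual export logic lives in `flux/trt/exporter`:

```python
# Export sketch with fixed, illustrative shapes; not the PR's exporter.
import os
import torch
from flux.util import load_ae  # autoencoder loader from the flux repo

onnx_dir = os.environ.get("ONNX_DIR", "./onnx")
os.makedirs(onnx_dir, exist_ok=True)

ae = load_ae("flux-schnell", device="cuda").to(torch.bfloat16).eval()
latent = torch.randn(1, 16, 128, 128, device="cuda", dtype=torch.bfloat16)

with torch.inference_mode():
    torch.onnx.export(
        ae.decoder,                   # decodes latents into images
        (latent,),
        os.path.join(onnx_dir, "ae_decoder.onnx"),
        input_names=["latent"],
        output_names=["image"],
        opset_version=17,
    )
```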
- **TensorRT Exporter**
  - Added the `flux/trt/exporter` package, containing code to export PyTorch models to ONNX and build TensorRT engines (see the build sketch below).
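The exporter's engine-building step maps onto the standard TensorRT Python API flow. A condensed sketch, assuming the single-file layout from the export sketch above:

```python
# Engine-build sketch using the standard TensorRT Python API; file paths
# are illustrative and the real logic lives in flux/trt/exporter.
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

with open("onnx/ae_decoder.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parse failed")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.BF16)  # bf16 kernels; fp16/fp8 noted as upcoming

engine_bytes = builder.build_serialized_network(network, config)
with open("trt_engines/ae_decoder.plan", "wb") as f:
    f.write(engine_bytes)
```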
- **TensorRT Engine Execution**
  - The `flux/trt/engine` package is responsible for running inference with the built TensorRT engines (see the inference sketch below).
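A matching inference sketch using the TensorRT runtime API directly. Tensor names, shapes, and dtypes follow the export sketch above and are assumptions; the actual wrapper classes live in `flux/trt/engine`:

```python
# Inference sketch with a prebuilt engine (TensorRT tensor-address API);
# tensor names match the export sketch above and are assumptions.
import tensorrt as trt
import torch

logger = trt.Logger(trt.Logger.INFO)
with open("trt_engines/ae_decoder.plan", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

latent = torch.randn(1, 16, 128, 128, device="cuda", dtype=torch.bfloat16)
image = torch.empty(1, 3, 1024, 1024, device="cuda", dtype=torch.bfloat16)

context.set_tensor_address("latent", latent.data_ptr())
context.set_tensor_address("image", image.data_ptr())

stream = torch.cuda.Stream()
context.execute_async_v3(stream.cuda_stream)
stream.synchronize()
```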
- **TensorRT Mixin Classes**
  - Added the `flux/trt/mixin` package with mixin classes that share parameters between the model-building and inference phases (illustrated below).
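The mixin pattern can be pictured as follows. The class and attribute names here are hypothetical and only illustrate sharing metadata between the build and run phases:

```python
# Hypothetical mixin: both the exporter (build phase) and the engine
# wrapper (inference phase) need the same shape metadata.
class LatentShapeMixin:
    latent_channels: int = 16   # illustrative values
    compression_factor: int = 8

    def latent_dims(self, height: int, width: int) -> tuple[int, int]:
        return height // self.compression_factor, width // self.compression_factor


class AEExporter(LatentShapeMixin):
    def sample_input_shape(self, height: int, width: int) -> tuple[int, ...]:
        # Shape of the dummy tensor passed to torch.onnx.export.
        h, w = self.latent_dims(height, width)
        return (1, self.latent_channels, h, w)


class AEEngine(LatentShapeMixin):
    def input_shape(self, height: int, width: int) -> tuple[int, ...]:
        # The same metadata is reused at inference time, not re-derived.
        h, w = self.latent_dims(height, width)
        return (1, self.latent_channels, h, w)
```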
- **TensorRT Manager**
  - Introduced `flux/trt/trt_manager.py` as the main TensorRT management class. It handles the conversion of PyTorch models to TensorRT engines and manages the TensorRT context for inference (see the sketch below).
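Conceptually, the manager ties the pieces together. A high-level sketch; the method names are assumptions, not the actual `trt_manager.py` API:

```python
# Hypothetical manager sketch: deserializes engines built by the exporter
# and keeps one execution context per model for inference.
import os
import tensorrt as trt


class TRTManager:
    def __init__(self, engine_dir: str):
        self.engine_dir = engine_dir
        self.logger = trt.Logger(trt.Logger.INFO)
        self.runtime = trt.Runtime(self.logger)
        self.engines = {}
        self.contexts = {}

    def load(self, name: str) -> trt.IExecutionContext:
        # Deserialize a prebuilt plan file and cache its execution context.
        path = os.path.join(self.engine_dir, f"{name}.plan")
        with open(path, "rb") as f:
            self.engines[name] = self.runtime.deserialize_cuda_engine(f.read())
        self.contexts[name] = self.engines[name].create_execution_context()
        return self.contexts[name]
```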