
🐛 [Bug] MutableTorchTensorRTModule.save errors out in low_vram_mode

Open · lanluo-nvidia opened this issue 4 months ago · 1 comment

Bug Description

python examples/apps/flux_demo.py --dtype fp16 --low_vram_mode --load_or_save save

WARNING:torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule:Trying to move the original PyTorch model. This will cause CPU offloading failing and increase GPU memory usage.If this is absolute necessary, please call module.pytorch_model.to(...) The model is still on the original device.
Traceback (most recent call last):
  File "C:\Users\local-lanl\git\py310\TensorRT\examples\apps\flux_demo.py", line 319, in <module>
    main(args)
  File "C:\Users\local-lanl\git\py310\TensorRT\examples\apps\flux_demo.py", line 270, in main
    torch_tensorrt.MutableTorchTensorRTModule.save(trt_gm, "mutable_trt_gm.pkl")
  File "C:\Users\local-lanl\git\venv_py310\lib\site-packages\torch_tensorrt\dynamo\runtime\_MutableTorchTensorRTModule.py", line 709, in save
    torch.save(module, path, pickle_protocol=4)
  File "C:\Users\local-lanl\git\venv_py310\lib\site-packages\torch\serialization.py", line 967, in save
    _save(
  File "C:\Users\local-lanl\git\venv_py310\lib\site-packages\torch\serialization.py", line 1213, in _save
    pickler.dump(obj)
MemoryError
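For reference, a stripped-down sketch of the failing call path (not the full Flux pipeline; the toy module, tensor shapes, and output file name are placeholders, and the settings mirror the low_vram_mode branch of flux_demo.py below):

import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.proj(x))


model = TinyModel().eval().half().cuda()

# Relevant compile settings from the low_vram_mode path in flux_demo.py
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    enabled_precisions={torch.float16},
    immutable_weights=False,
    use_python_runtime=False,
    offload_module_to_cpu=True,  # set by --low_vram_mode
)
# First forward call triggers compilation
trt_gm(torch.randn(2, 64, device="cuda", dtype=torch.float16))

# save() pickles the whole module; this is the call that raises MemoryError in the traceback above
torch_tensorrt.MutableTorchTensorRTModule.save(trt_gm, "mutable_trt_gm.pkl")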

To Reproduce

Steps to reproduce the behavior:

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

lanluo-nvidia · Sep 15 '25 21:09

Here is the flux_demo.py:

import argparse
import os
import re
import sys
import time

import gradio as gr
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from accelerate.hooks import remove_hook_from_module
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

DEVICE = "cuda:0"


def compile_model(
    args,
) -> tuple[
    FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
]:
    use_explicit_typing = False
    if args.use_sdpa:
        # SDPA currently does not work correctly with the Flux model, so it is disabled by default.
        # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
        sys.path.append(
            os.path.join(os.path.dirname(__file__), "../../tools/llm/torchtrt_ext")
        )
        import register_sdpa

    if args.dtype == "fp4":
        use_explicit_typing = True
        enabled_precisions = {torch.float4_e2m1fn_x2}
        ptq_config = mtq.NVFP4_DEFAULT_CFG
        if args.fp4_mha:
            from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG

            ptq_config = NVFP4_FP8_MHA_CONFIG

    elif args.dtype == "fp8":
        enabled_precisions = {torch.float8_e4m3fn, torch.float16}
        ptq_config = mtq.FP8_DEFAULT_CFG

    elif args.dtype == "int8":
        enabled_precisions = {torch.int8, torch.float16}
        ptq_config = mtq.INT8_DEFAULT_CFG
        ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None

    elif args.dtype == "fp16":
        enabled_precisions = {torch.float16}

    print(f"\nUsing {args.dtype}")

    if args.model not in [
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-Kontext-dev",
    ]:
        raise ValueError(f"Model {args.model} is not supported")
    pipe = FluxPipeline.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
    ).to(torch.float16)

    if args.low_vram_mode:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(DEVICE)

    backbone = pipe.transformer
    backbone.eval()

    def filter_func(name):
        pattern = re.compile(
            r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
        )
        return pattern.match(name) is not None

    def do_calibrate(
        pipe,
        prompt: str,
    ) -> None:
        """
        Run calibration steps on the pipeline using the given prompts.
        """
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=20,
            generator=torch.Generator("cuda").manual_seed(0),
        ).images[0]

    def forward_loop(mod):
        # Switch the pipeline's backbone, run calibration
        pipe.transformer = mod
        do_calibrate(
            pipe=pipe,
            prompt="a dog running in a park",
        )

    if args.dtype != "fp16":
        backbone = mtq.quantize(backbone, ptq_config, forward_loop)
        mtq.disable_quantizer(backbone, filter_func)

    batch_size = 2 if args.dynamic_shapes else 1
    if args.dynamic_shapes:
        BATCH = torch.export.Dim("batch", min=1, max=8)
        dynamic_shapes = {
            "hidden_states": {0: BATCH},
            "encoder_hidden_states": {0: BATCH},
            "pooled_projections": {0: BATCH},
            "timestep": {0: BATCH},
            "txt_ids": {},
            "img_ids": {},
            "guidance": {0: BATCH},
            "joint_attention_kwargs": {},
            "return_dict": None,
        }
    else:
        dynamic_shapes = None

    settings = {
        "strict": False,
        "allow_complex_guards_as_runtime_asserts": True,
        "enabled_precisions": enabled_precisions,
        "truncate_double": True,
        "min_block_size": 1,
        "use_python_runtime": False,
        "immutable_weights": False,
        "offload_module_to_cpu": args.low_vram_mode,
        "use_explicit_typing": use_explicit_typing,
    }
    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.enable_sequential_cpu_offload()
        remove_hook_from_module(pipe.transformer, recurse=True)
        pipe.transformer.to(DEVICE)
    
    if args.load_or_save == "save":
        with torch_tensorrt.logging.debug():
            start = time.time()
            trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
            end = time.time()
            print(f"Time taken to compile the model: {end - start} seconds")
    else:
        trt_gm = torch_tensorrt.MutableTorchTensorRTModule.load("mutable_trt_gm.pkl")
    if dynamic_shapes:
        trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
    pipe.transformer = trt_gm
    seed = 42
    image = pipe(
        [
            "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"
        ],
        output_type="pil",
        num_inference_steps=30,
        num_images_per_prompt=batch_size,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images
    print(f"generated {len(image)} images")
    image[0].save("forest.png")

    torch.cuda.empty_cache()

    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.to(DEVICE)

    return pipe, backbone, trt_gm


def launch_gradio(pipeline, backbone, trt_gm):

    def generate_image(prompt, inference_step, batch_size=2):
        start_time = time.time()
        image = pipeline(
            prompt,
            output_type="pil",
            num_inference_steps=inference_step,
            num_images_per_prompt=batch_size,
        ).images
        end_time = time.time()
        return image, end_time - start_time

    def model_change(model):
        if model == "Torch Model":
            pipeline.transformer = backbone
            backbone.to(DEVICE)
        else:
            backbone.to("cpu")
            pipeline.transformer = trt_gm
            torch.cuda.empty_cache()

    def load_lora(path):
        pipeline.load_lora_weights(
            path,
            adapter_name="lora1",
        )
        pipeline.set_adapters(["lora1"], adapter_weights=[1])
        pipeline.fuse_lora()
        pipeline.unload_lora_weights()
        print("LoRA loaded! Begin refitting")
        generate_image(["Test"], 2)
        print("Refitting Finished!")

    # Create Gradio interface
    with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
        gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

        with gr.Row():
            with gr.Column():
                # Input components
                prompt_input = gr.Textbox(
                    label="Prompt", placeholder="Enter your prompt here...", lines=3
                )
                model_dropdown = gr.Dropdown(
                    choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                    value="Torch-TensorRT Accelerated Model",
                    label="Model Variant",
                )

                lora_upload_path = gr.Textbox(
                    label="LoRA Path",
                    placeholder="Enter the LoRA checkpoint path here. It could be a local path or a Hugging Face URL.",
                    value="gokaygokay/Flux-Engrave-LoRA",
                    lines=2,
                )
                num_steps = gr.Slider(
                    minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
                )
                batch_size = gr.Slider(
                    minimum=1, maximum=8, value=1, step=1, label="Batch Size"
                )

                generate_btn = gr.Button("Generate Image")
                load_lora_btn = gr.Button("Load LoRA")

            with gr.Column():
                # Output component
                output_image = gr.Gallery(label="Generated Image")
                time_taken = gr.Textbox(
                    label="Generation Time (seconds)", interactive=False
                )

        # Connect the button to the generation function
        model_dropdown.change(model_change, inputs=[model_dropdown])
        load_lora_btn.click(
            fn=load_lora,
            inputs=[
                lora_upload_path,
            ],
        )

        # Update generate button click to include time output
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt_input,
                num_steps,
                batch_size,
            ],
            outputs=[output_image, time_taken],
        )
        demo.launch()


def main(args):
    pipe, backbone, trt_gm = compile_model(args)
    if args.load_or_save == "save":
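        # The MemoryError in the traceback above is raised by this save() call when --low_vram_mode is set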
        torch_tensorrt.MutableTorchTensorRTModule.save(trt_gm, "mutable_trt_gm.pkl")
    # launch_gradio(pipe, backbone, trt_gm)


# Launch the interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run Flux quantization with different dtypes"
    )
    parser.add_argument(
        "--load_or_save",
        choices=["load", "save"],
        default="",
        help="load or save",
    )
    parser.add_argument(
        "--model",
        default="black-forest-labs/FLUX.1-dev",
        help="Model to use",
    )
    parser.add_argument(
        "--use_sdpa",
        action="store_true",
        help="Use sdpa",
        default=False,
    )
    parser.add_argument(
        "--dtype",
        choices=["fp4", "fp8", "int8", "fp16"],
        default="fp16",
        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
    )
    parser.add_argument(
        "--fp4_mha",
        action="store_true",
        help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
    )
    parser.add_argument(
        "--low_vram_mode",
        action="store_true",
        help="Use low VRAM mode when you have a small GPU (<=32GB)",
    )
    parser.add_argument(
        "--dynamic_shapes",
        "-d",
        action="store_true",
        help="Use dynamic shapes",
    )
    args = parser.parse_args()
    main(args)

lanluo-nvidia · Sep 15 '25 21:09