
🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving

Open • readleyj opened this issue 1 year ago • 13 comments

Bug Description

I am trying to use torch_tensorrt.dynamo.compile() to AOT-compile the UNet of a StableDiffusionPipeline from the diffusers library (version 0.30.2). I can export the UNet with torch.export.export(), compile it with torch_tensorrt.dynamo.compile(), and save it with torch_tensorrt.save(). However, loading the saved compiled UNet with torch.export.load() fails with a RuntimeError.

To Reproduce

Run the following script:

import functools

import torch
import torch_tensorrt

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

def generate_sd_unet_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand([], device="cuda", dtype=torch.float32) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    
    return sample, timestep, encoder_hidden_states

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    # Bind return_dict=False so forward returns a tuple instead of UNet2DConditionOutput
    unet_model = pipe.unet.eval()
    unet_model.forward = functools.partial(unet_model.forward, return_dict=False)

    arg_inputs_unet = generate_sd_unet_inputs()
    expected_outputs_unet = unet_model(*arg_inputs_unet)

    # Export the UNet, compile the ExportedProgram with Torch-TensorRT, then save it
    unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)

    with torch_tensorrt.logging.errors():
        compiled_unet = torch_tensorrt.dynamo.compile(
            unet_exported_program,
            inputs=arg_inputs_unet,
            enabled_precisions={torch.float16, torch.float32},
            truncate_double=True,
        )

    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    # Loading fails here with "RuntimeError: Node redefined name getitem_130!"
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()

Error message

...
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:370: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_4_engine target _run_on_acc_4_engine _run_on_acc_4_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_6_engine target _run_on_acc_6_engine _run_on_acc_6_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_8_engine target _run_on_acc_8_engine _run_on_acc_8_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1593: UserWarning: Additional 16 warnings suppressed about get_attr references
  warnings.warn(

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 48
     40     compiled_unet = torch_tensorrt.dynamo.compile(
     41         unet_exported_program,
     42         inputs=arg_inputs_unet,
     43         enabled_precisions={torch.float16, torch.float32},
     44         truncate_double=True,
     45     )
     47 torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
---> 48 loaded_unet = torch.export.load("sd_unet_compiled.ep")

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/__init__.py:476, in load(f, extra_files, expected_opset_version)
    468 artifact: SerializedArtifact = SerializedArtifact(
    469     serialized_exported_program,
    470     serialized_state_dict,
    471     serialized_constants,
    472     serialized_example_inputs,
    473 )
    475 # Deserialize ExportedProgram
--> 476 ep = deserialize(artifact, expected_opset_version)
    478 return ep

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2437, in deserialize(artifact, expected_opset_version)
   2433 exported_program_dict = json.loads(exported_program_str)
   2434 serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
   2435 return (
   2436     ExportedProgramDeserializer(expected_opset_version)
-> 2437     .deserialize(
   2438         serialized_exported_program,
   2439         artifact.state_dict,
   2440         artifact.constants,
   2441         artifact.example_inputs,
   2442     )
   2443 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2329, in ExportedProgramDeserializer.deserialize(self, exported_program, state_dict, constants, example_inputs)
   2314 res = (
   2315     GraphModuleDeserializer()
   2316     .deserialize(
   (...)
   2322     )
   2323 )
   2324 range_constraints = self.deserialize_range_constraints(
   2325     symbol_name_to_range,
   2326     res.names_to_symbols,
   2327 )
-> 2329 return ep.ExportedProgram(
   2330     root=res.graph_module,
   2331     graph=res.graph_module.graph,
   2332     graph_signature=res.signature,
   2333     state_dict=res.state_dict,  # type: ignore[arg-type]
   2334     range_constraints=range_constraints,
   2335     module_call_graph=res.module_call_graph,
   2336     example_inputs=res.example_inputs,
   2337     constants=res.constants,
   2338     verifiers=[load_verifier(v) for v in exported_program.verifiers],
   2339 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:700, in ExportedProgram.__init__(self, root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, constants, verifiers)
    698 self._verifiers = verifiers
    699 # Validate should be always the last step of the constructor.
--> 700 self.validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1117, in ExportedProgram.validate(self)
   1115 @compatibility(is_backward_compatible=False)
   1116 def validate(self):
-> 1117     self._validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1126, in ExportedProgram._validate(self)
   1122 assert (
   1123     len(self.verifiers) > 0
   1124 ), "ExportedProgram must have at least one verifier."
   1125 for v in self.verifiers:
-> 1126     v().check(self)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:155, in Verifier.check(self, ep)
    153 @final
    154 def check(self, ep: "ExportedProgram") -> None:
--> 155     self._check_graph_module(ep.graph_module)
    156     _verify_exported_program_module_call_graph(ep)
    157     _verify_exported_program_signature(ep)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:214, in Verifier._check_graph_module(self, gm)
    211 if not isinstance(mod, torch.fx.GraphModule):
    212     continue
--> 214 mod.graph.lint()
    215 for node in mod.graph.nodes:
    216     # TODO(T140410192): should have fake tensor for all dialects
    217     if node.op in {"call_module", "call_method"}:

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1549, in Graph.lint(self)
   1546     seen_values.add(node)
   1548     if node.name in seen_names:
-> 1549         raise RuntimeError(f'Node redefined name {node.name}!')
   1550     seen_names.add(node.name)
   1552 # Check targets are legit

RuntimeError: Node redefined name getitem_130!
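For context, the check that fails is torch.fx.Graph.lint(). Per the traceback, the export verifier calls it on every fx.GraphModule in the loaded program, and it rejects any graph in which two nodes share a name. The following is a minimal CPU-only sketch on a toy module that forces a duplicate name to reproduce the same class of error. TwoOps and duplicate_node_names are hypothetical helpers, not part of PyTorch or Torch-TensorRT. The sketch only illustrates the failing check; it does not explain why the Torch-TensorRT artifact ends up with two nodes named getitem_130.

```python
import collections

import torch
import torch.fx


class TwoOps(torch.nn.Module):
    # Hypothetical toy module: traces to nodes x, add, mul, output
    def forward(self, x):
        return x + 1, x * 2


def duplicate_node_names(gm: torch.fx.GraphModule) -> dict:
    """Hypothetical helper: map each fx.GraphModule path reachable from gm to
    the node names used more than once in its graph (lint checks each graph)."""
    dups = {}
    for path, mod in gm.named_modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        counts = collections.Counter(node.name for node in mod.graph.nodes)
        repeated = sorted(name for name, n in counts.items() if n > 1)
        if repeated:
            dups[path] = repeated
    return dups


gm = torch.fx.symbolic_trace(TwoOps())
gm.graph.lint()  # a freshly traced graph has unique node names
print(duplicate_node_names(gm))  # {}

# Deliberately corrupt the graph: give the "mul" node the "add" node's name
add_node, mul_node = [n for n in gm.graph.nodes if n.op == "call_function"]
mul_node.name = add_node.name
print(duplicate_node_names(gm))  # {'': ['add']}

try:
    gm.graph.lint()
except RuntimeError as err:
    print(err)  # Node redefined name add!
```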

Expected behavior

torch.export.load() should load the saved compiled model without raising an error.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240912+cu124
  • PyTorch Version (e.g. 1.0): 2.5.0.dev20240912+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 22.04.4 LTS (x86_64)
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11.10
  • CUDA version: 12.4
  • GPU models and configuration: half of an H100 (partitioned with MIG)
  • Any other relevant information: Using diffusers version 0.30.2

Additional context

I have to use functools.partial() in the code above because the UNet's forward method returns a UNet2DConditionOutput dataclass by default. I tried to avoid functools.partial() by using torch.export.register_dataclass() instead, but hit the same runtime error.

Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.

readleyj, Sep 16 '24 12:09