🐛 [Bug] Encountered bug when using Torch-TensorRT
Bug Description
I adapted the example at https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/mutable_torchtrt_module_example.py, replacing the diffusion model with the Hugging Face Whisper model.
To Reproduce
import numpy as np
import torch
import torch_tensorrt as torch_trt
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# from datasets import load_dataset
from peft import PeftModel
# import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# %%
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16

model_id = "distil-whisper/distil-large-v3"

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refitable": True,
    }
    base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    base_model.eval().to(device)
    # model = models.resnet18(pretrained=True).eval().to("cuda")
    mutable_module = torch_trt.MutableTorchTensorRTModule(base_model, **settings)
Steps to reproduce the behavior:
1. Run the code above.
Traceback (most recent call last):
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 108, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 442, in call_function
    return tx.inline_user_function_return(
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 108, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 898, in step
    self.exception_handler(e)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1496, in exception_handler
    raise raised_exception
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1388, in RAISE_VARARGS
    self._raise_exception_variable(inst)
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1381, in _raise_exception_variable
    raise exc.ObservedException(f"raised exception {val}")
torch._dynamo.exc.ObservedException: raised exception ExceptionVariable()
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/home/jalagurajah/Desktop/ROME/My Documents/ASR_2024/torchrttest.py", line 42, in
from user code:
  File "/home/jalagurajah/anaconda3/envs/torchrt/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py", line 1764, in forward
    outputs = self.model(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
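The hint in the last line of the error can also be applied from inside the script instead of the shell. A minimal sketch, assuming the reproduction script is the entry point and that (as far as I know) both variables are read when torch is first imported, so they need to be set before that import:

```python
import os

# Set both variables before torch is imported so Dynamo picks them up.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

# Import torch (and torch_tensorrt) only after this point, e.g.:
# import torch
# import torch_tensorrt as torch_trt
```

Setting them on the command line when launching the script should have the same effect.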
Expected behavior
The model compiles without errors.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version : 12.4
- PyTorch Version : 2.5.0.dev20240830+cu124
- CPU Architecture: x86
- OS (e.g., Linux): ubuntu
- How you installed PyTorch (conda, pip, libtorch, source): nightly install
- Python version: 3.10
- CUDA version: 12.4
- GPU models and configuration: A100 80G
Have you tried exporting whisper with torch.export? Does that work properly? Seems like right now that is the step that is failing
No, I was not able to export the Whisper model successfully.
A successful torch.export is a prerequisite for using the MutableModule (the export is done either by you or by us). It might be worth opening an issue on pytorch/pytorch for this.
Alternatively, you can try torch.compile(..., backend="tensorrt"), which is a bit more flexible.
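To make the torch.compile suggestion concrete, here is a minimal sketch. The helper name compile_module and its backend/options parameters are made up for illustration and are not part of Torch-TensorRT. It assumes that importing torch_tensorrt registers the "tensorrt" backend and that the backend accepts an options dict with enabled_precisions. The backend is a parameter only so the same call can be smoke-tested with the built-in "eager" backend on a machine without TensorRT.

```python
import torch


def compile_module(model, backend="tensorrt", options=None):
    """Hypothetical helper: wrap torch.compile with a swappable backend."""
    if backend == "tensorrt":
        # Importing torch_tensorrt registers the "tensorrt" backend.
        import torch_tensorrt  # noqa: F401
    compile_kwargs = {"backend": backend}
    if options:
        compile_kwargs["options"] = options
    return torch.compile(model, **compile_kwargs)


# Intended use on a GPU machine with torch_tensorrt installed (not run here):
# compiled = compile_module(
#     base_model,
#     options={"enabled_precisions": {torch.float16}},
# )
```

Note that torch.compile is lazy: nothing is traced until the compiled module is called for the first time, so any compilation error would show up at that first call rather than at the compile_module call.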
import torch
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_sample = ds[0]["audio"]
input_features = processor(
    audio_sample["array"],
    sampling_rate=audio_sample["sampling_rate"],
    return_tensors="pt",
).input_features
with torch.no_grad():
    print(hf_model.generate(input_features))
exported_model = torch.export.export(hf_model, args=(input_features,))
torch.export.save(exported_model, "model.pt")
pt_model = torch.export.load('model.pt')
with torch.no_grad():
    print(pt_model.module().generate(input_features))
I reproduced the error with the code above.