🐛 [Bug] Encountered bug when using Torch-TensorRT(torch.nn.LSTM)
Bug Description
Encountered the following error when using Torch-TensorRT to convert torch.nn.LSTM in the Docker image nvcr.io/nvidia/pytorch:23.12-py3:
NotImplementedError: aten::_cudnn_rnn_flatten_weight: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl
To Reproduce
Example code:
import torch
import torch_tensorrt
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(1024, 1024, batch_first=True)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output sequence
        x = self.lstm(x)[0]
        return x


# FP16 model and input on the GPU: (batch=100, seq_len=200, features=1024)
model = Model().half().eval().cuda()
inputs = [torch.randn(100, 200, 1024).half().cuda()]

# Fails with the NotImplementedError reported above
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
trt_gm(*inputs)
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version: 2.2.0a0
- PyTorch Version: 2.2.0a0+81ea7a4
- CPU Architecture:
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.10.12
- CUDA version: 12.3
- GPU models and configuration: A100
- Any other relevant information:
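The environment fields above can be gathered with a short script. This is a minimal sketch, not part of the issue template: `collect_env` is a hypothetical helper, and it only reads standard attributes (`torch.__version__`, `torch.version.cuda`, `torch_tensorrt.__version__`, the `platform` module). PyTorch also ships a fuller report via `python -m torch.utils.collect_env`.

```python
import platform

import torch


def collect_env() -> dict:
    """Hypothetical helper: gather the fields asked for in the Environment section."""
    info = {
        "PyTorch Version": torch.__version__,
        "CPU Architecture": platform.machine(),
        "OS": platform.system(),
        "Python version": platform.python_version(),
        # None when PyTorch was built without CUDA support
        "CUDA version": torch.version.cuda,
        "GPU models and configuration": (
            ", ".join(
                torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
            )
            if torch.cuda.is_available()
            else "none detected"
        ),
    }
    try:
        import torch_tensorrt  # optional: only present if Torch-TensorRT is installed

        info["Torch-TensorRT Version"] = torch_tensorrt.__version__
    except (ImportError, OSError):
        info["Torch-TensorRT Version"] = "not installed"
    return info


if __name__ == "__main__":
    # Print in the same "- Field: value" shape as the template
    for key, value in collect_env().items():
        print(f"- {key}: {value}")
```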
Additional context
@narendasan Hi, is there any update?
Hello, as an update on this issue: one workaround is to compile with ir="torch_compile" and set torch._dynamo.config.allow_rnn = True at the top of the script.
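A minimal sketch of the ir="torch_compile" workaround, assuming Torch-TensorRT 2.x and a CUDA GPU. The tensor sizes are reduced from the original report for illustration, `build_and_run` is a hypothetical helper, and passing `enabled_precisions={torch.half}` is an assumption added to match the FP16 model, not part of the suggested workaround. When torch_tensorrt or a GPU is unavailable, the script falls back to plain eager execution so it still runs.

```python
import torch
import torch._dynamo
import torch.nn as nn

# The suggested workaround: let Dynamo trace through RNN modules such as nn.LSTM.
# This must be set before any torch.compile / torch_tensorrt.compile call.
torch._dynamo.config.allow_rnn = True


class Model(nn.Module):
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output sequence
        return self.lstm(x)[0]


def build_and_run(batch: int = 4, seq_len: int = 16, hidden: int = 64):
    """Hypothetical helper: compile via ir="torch_compile" when possible, else run eagerly."""
    model = Model(hidden).eval()
    x = torch.randn(batch, seq_len, hidden)

    try:
        import torch_tensorrt  # requires TensorRT and a CUDA build of PyTorch

        have_trt = torch.cuda.is_available()
    except (ImportError, OSError):
        have_trt = False

    if have_trt:
        model, x = model.half().cuda(), x.half().cuda()
        # ir="torch_compile" wraps the model with torch.compile; the TensorRT
        # engine is built lazily on the first forward call below.
        model = torch_tensorrt.compile(
            model,
            ir="torch_compile",
            inputs=[x],
            enabled_precisions={torch.half},  # assumption: request FP16 kernels
        )

    with torch.no_grad():
        out = model(x)
    return out, have_trt


if __name__ == "__main__":
    out, used_trt = build_and_run()
    print(tuple(out.shape), "TensorRT" if used_trt else "eager fallback")
```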
For the ir="dynamo" path, there is a workaround described here: https://github.com/pytorch/pytorch/issues/121761#issuecomment-2021696208. The resulting gm object can then be used with Torch-TensorRT by passing it into the torch_tensorrt.compile call.
A more robust fix is pending resolution of these related issues: [pytorch/pytorch/issues/120626, pytorch/pytorch/issues/121761]