TensorRT
🐛 [Bug] Compilation causes error: `RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.`
Bug Description
Compiling the graph throws the following error:
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.
Looking at the output TorchScript graph, %47 is defined in a prior node; however, it does not appear to be visible in the current node.
To Reproduce
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging

logging.set_reportable_log_level(logging.Level.Graph)

torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)
        tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out
        return tril


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

    model = Model().eval().to(DEVICE)
    out = model(tensor)
    print(f"Model: {out}")

    model_trt = torchtrt.compile(
        model,
        inputs=[
            torchtrt.Input(shape=SHAPE),
        ],
        enabled_precisions={torch.float},
        truncate_long_and_double=True,
    )
    out_trt = model_trt(tensor)
    print(f"Model TRT: {out_trt}")

    assert torch.max(torch.abs(out - out_trt)) < 1e-6
Throws the following error:
Traceback (most recent call last):
File "/scripts/tril.py", line 39, in <module>
model_trt = torchtrt.compile(
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 97, in compile
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.
Expected behavior
Compilation should not fail, and should produce the following output when run:
Model: tensor([[[0.5434]]], device='cuda:0', grad_fn=<CopySlices>)
Environment
Ubuntu 18.04 x86-64, run with NGC 21.11-py3 and 22.02-py3.
Additional context
See output.txt for full torchscript graph output.
This looks similar to issue #756, with fix #757.
Looking at the sources, it looks like this fix may not have made it into /release/ngc/22.02, but should be present in /release/ngc/22.03 and afterwards. At the time of this writing, only 22.02-py3 is available. I'll close this pending testing and availability of 22.03-py3.
Just tried fix #757 with master commit 11bcb98d3cd680c3c34e6cc4c4efdc7512c144cc
built in NGC container nvcr.io/nvidia/tensorrt:22.02-py3
and PyTorch 1.10, and the error persists, so it's likely that #757 does not address this issue.
Looks like the issue is in this line:
tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
as changing this line to:
tril = torch.zeros(1, 1, 1).cuda()
appears to bypass the issue.
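For clarity, the bypass applied to the repro's forward looks like this (just a sketch of the workaround above, not a proper fix):
def forward(self, a):
    out = self.lin(a)
    # Workaround: build the tensor without device=a.device / dtype=out.dtype,
    # then move it to the GPU explicitly; this appears to bypass the partitioning error.
    tril = torch.zeros(1, 1, 1).cuda()
    indices = torch.tril_indices(1, 1)
    tril[:, indices[0], indices[1]] = out
    return tril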
@peri044 Can you take a look at this? It looks related to your past work on dtype.
@chaoz-dev I had the same error
[Error thrown at core/partitioning/shape_analysis.cpp:67] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 12065 produced from %12065 : int = aten::size(%out.29, %11826), scope: __module.stylegan_decoder/__module.stylegan_decoder.style_conv1/__module.stylegan_decoder.style_conv1.modulated_conv
and your method of bypassing is not applicable in my case. Could you suggest something else?
You're probably facing a different but possibly related issue. Can you file a new bug report with the above information?
I took a look into this issue. This is caused by the resolveNonTensorInput function. What happens here is that when you use:
tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
it introduces a NonTensorInput for a.device in the minigraph, which causes the resolveNonTensorInput function to segment this subgraph again. This explains why it's fine when you change it to
tril = torch.zeros(1, 1, 1).cuda()
Let me see if we can refactor this function, since it's making a mess here.
@chaoz-dev I raised a PR for this bug just now here https://github.com/NVIDIA/Torch-TensorRT/pull/1024. The model you provided should be supported now. Please take a look and ping me if there are any other issues.
Sorry to re-raise this issue, but I'm still getting the same runtime error for deformable convolutions on the latest build of master (commit 91a92ca4), which includes PR #1024.
Expanding upon the original reproduction code above (trt_bug.py), I'm getting
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 62 produced from %62 : int = aten::__getitem__(%61, %54) # /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/deform_conv.py:71:28 in lowering graph for mini graph input.
ENV info:
Collecting environment information...
PyTorch version: 1.11.0a0+gitbc2c6ed
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.4 (main, Apr 2 2022, 09:04:19) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-33-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.64
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 515.43.04
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.5
[pip3] torch==1.11.0
[pip3] torch-tensorrt==1.2.0a0+91a92ca4
[pip3] torchvision==0.12.0
[conda] Could not collect
I am still getting this runtime error on another model; I don't believe I'm using deformable convolution here. I'm in the process of trying to clean up source code to show, but the error is something like:
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 1442 produced from %x.2 : Tensor, %1442 : Float(32, 12, 3, 3, ...
with the ... replaced by about 25K characters of a dump of the graph.
Note also that this error did not occur with float16, only with int8.
This is with Torch-TensorRT v1.1.0, so PR #1024 is included.
@BrettRyland Could you please try this PR: https://github.com/pytorch/TensorRT/pull/1067 We refactored this part recently, and I tried your model with this PR, didn't find the issue you had.
@Hodapp87 Could you please try this #1067 as well? Or could you please provide a reproducer if you still hit this issue?
@BrettRyland Could you please try this PR: #1067 We refactored this part recently, and I tried your model with this PR, didn't find the issue you had.
I still get this error
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 62 produced from %62 : int = aten::__getitem__(%61, %54) # /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/deform_conv.py:71:28 in lowering graph for mini graph input.
(full log: trt_bug_log.txt) with PR #1067 merged in:
brett@br-workhorse:~/github/TensorRT/py$ git log --oneline --graph
* 22d91f5e (HEAD -> master) fix: fix the bug that tag Constant node as fallback node
* ccb826e7 Merge remote-tracking branch 'origin' into refactor_segmentation
|\
| * 91a92ca4 (origin/master, origin/HEAD) docs: [Automated] Regenerating documenation for dcf3386
brett@br-workhorse:/tmp$ python3 -c 'import torch_tensorrt; print(torch_tensorrt.__version__)'
1.2.0a0+22d91f5e
Side note: the trt_bug.py script has a typo on line 93; it should be out_trt2 = model2(tensor2), not out_trt2 = model2(tensor), but I guess you saw that if you got it to run without issues.
Another side note: I don't think it's relevant to this issue, but I get the warning
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
despite not having cuBLAS 11.8.0 on my system:
brett@br-workhorse:/storage/github/TensorRT$ sudo updatedb && locate -i libcublas
/home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64/libcublas.so
/home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64_Sstubs/libcublas.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas.so.11
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas.so.11.10.1.25
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt.so.11
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt.so.11.10.1.25
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt_static.a
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas_static.a
/usr/local/cuda-11.7/targets/x86_64-linux/lib/stubs/libcublas.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/stubs/libcublasLt.so
/usr/share/doc/libcublas-11-7
/usr/share/doc/libcublas-dev-11-7
/usr/share/doc/libcublas-11-7/changelog.Debian.gz
/usr/share/doc/libcublas-11-7/copyright
/usr/share/doc/libcublas-dev-11-7/changelog.Debian.gz
/usr/share/doc/libcublas-dev-11-7/copyright
/var/lib/dpkg/info/libcublas-11-7.list
/var/lib/dpkg/info/libcublas-11-7.md5sums
/var/lib/dpkg/info/libcublas-dev-11-7.list
/var/lib/dpkg/info/libcublas-dev-11-7.md5sums
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64/libcublas.so
lrwxrwxrwx 1 brett brett 103 May 25 16:55 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64/libcublas.so -> /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so
lrwxrwxrwx 1 root root 15 Apr 6 04:07 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so -> libcublas.so.11
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11
lrwxrwxrwx 1 root root 23 Apr 6 04:07 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11 -> libcublas.so.11.10.1.25
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11.10.1.25
-rw-r--r-- 1 root root 156720544 Apr 6 04:07 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11.10.1.25
@Hodapp87 Could you please try this #1067 as well? Or could you please provide a reproducer if you still hit this issue?
It still occurs for me too. I am trying to provide code that can reproduce, but much of this is proprietary in nature and so it may take some time to disentangle it.
@BrettRyland Did you clear your cache? I could get your model to work after I used that PR.
You could do:
pip3 uninstall torch_tensorrt
to uninstall the previously installed torch_tensorrt. Do it multiple times to ensure that there aren't any library copies left.
Then:
python3 setup.py clean
python3 setup.py install
Could you please print some logs to make sure that the PR works? I also had issues where I found I wasn't actually using the merged code because I hadn't cleared the cache.
@Hodapp87 Can you try it as well?
For testing this, I used the Dockerfile straight out of the repository and then ran inside this container. Unless this caches something I'm unaware of, this should have been a clean build.
Here's the truncated output of my run, which shows versions as well (22d91f5e should be your PR's commit):
------------------------------------------------------------
torch version: 1.11.0+cu102
torch_tensorrt version: 1.2.0a0+22d91f5e
------------------------------------------------------------
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.2. Assuming it is Float32. If not, specify input type explicity
WARNING: [Torch-TensorRT] - Input type for doing shape analysis could not be determined, defaulting to F32
Traceback (most recent call last):
File "./compile.py", line 119, in <module>
compiled = torch_tensorrt.compile(mdl_ts, **compile_spec)
File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py", line 115, in compile
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 1935 produced from %x.2 : Tensor, %1935 : Float(32, 12, 3, 3, strides=[108, 9, 3, 1], ...
If I get a chance soon, I'll see if I can extract a simpler model out of this that I can send to try.
Clearing the cache (I also ran bazel clean --expunge in the top-level directory and removed ~/.cache/bazel) allowed the test model to compile and run without problems, but my full model still gives the same RuntimeError, just in a different place.
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 1162 produced from %x.6 : Tensor, %1162 : Long(requires_grad=0, device=cpu) = prim::Param() in lowering graph for mini graph input.
I'll need to try to isolate where in my full model it's happening now to see what's triggering it.
OK, I've reduced my model to a smaller repro script trt_bug.py which still gives
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 61 produced from %x.1 : Tensor, %61 : Long(requires_grad=0, device=cpu) = prim::Param() in lowering graph for mini graph input.
It appears to be caused by using a single-valued int64 index tensor in an aten::index_put_ operation:
scores[:, self.anchor_always_index, :] = self.false_anchor_score
where self.false_anchor_score is a registered buffer.
Replacing the index tensor with its value (using .item()) causes a different error:
RuntimeError: [Error thrown at ./core/conversion/var/Var_inl.h:37] Expected isIValue() to be true but got false
Requested unwrapping of arg assuming it was an IValue, however arg type is nvinfer1::ITensor
which can also be avoided by using an explicit tensor instead of the self.false_anchor_score buffer.
Note that torch will happily script or trace the model with
scripted_model = torch.jit.script(model)
or
traced_model = torch.jit.trace(model, torch.rand((1, *model.input_size), device=DEVICE))
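For reference, here's a minimal sketch of the kind of pattern that triggers it for me (the module, names, and shapes below are made up for illustration; they are not my actual model):
import torch
import torch.nn as nn

class Repro(nn.Module):
    def __init__(self, num_anchors: int = 8, anchor_always_index: int = 3):
        super().__init__()
        # Single-valued int64 index tensor used in the indexed assignment below.
        self.register_buffer("anchor_always_index", torch.tensor([anchor_always_index], dtype=torch.int64))
        # Registered buffer whose value gets written into the scores tensor.
        self.register_buffer("false_anchor_score", torch.tensor(-10.0))
        self.head = nn.Linear(16, num_anchors * 4)

    def forward(self, x):
        scores = self.head(x).view(x.shape[0], -1, 4)
        # aten::index_put_ with an int64 index tensor and a buffer value:
        # this is the kind of assignment that triggers the partitioning error for me.
        scores[:, self.anchor_always_index, :] = self.false_anchor_score
        return scores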
Hi @BrettRyland I took a look into your model. This line https://github.com/pytorch/TensorRT/blob/058a511ca0fd0df0587224f20e04067326d692e4/core/lowering/lowering.cpp#L95 introduces another input as %61 to the whole model's input, which might have something to do with the underlying function in PyTorch here https://github.com/pytorch/pytorch/blob/6114b0f921d5568c582d0168501f780df7a66d0d/torch/csrc/jit/passes/lower_graph.cpp#L149. This seems not an issue related to fallback, I'm now looking into it to figure out what happened.
Can I get more details? Thanks @Hodapp87.
Hey @BrettRyland, did you manage to bypass the issue? After detailed investigation, it seems that this error comes from PyTorch, and we can bypass it by using explicit tensors when building models, though we could also do some work, like post-processing the graph produced by PyTorch, to erase the introduced missing input.
I managed to bypass the issue by using
anchor_always_score = scores[:, self.anchor_always_index, :] # Avoid indexed assignment, which quantisation doesn't like. https://discuss.pytorch.org/t/how-to-get-around-proxy-object-does-not-support-item-assignment/122655
anchor_always_score *= 0
anchor_always_score += self.false_anchor_score
instead of
scores[:, self.anchor_always_index, :] = self.false_anchor_score
as the latter was also causing issues with quantisation (which I've also been trying to get to work). However, I haven't yet managed to adjust my full model to be compatible with TensorRT or quantisation (I've been busy with other stuff), so I can't confirm whether it gives correct results for the full model yet, though it does seem to work for the trt_bug.py script.
I am having a similar issue converting a TorchScript model with TorchTensorRT. I see that the consensus is that you can refactor your code as a workaround to the issue, but it's not immediately clear where the error is coming from:
Traceback (most recent call last):
File "scripts/evaluation/infer_tables_tensorrt.py", line 42, in <module>
main()
File "scripts/evaluation/infer_tables_tensorrt.py", line 20, in main
torch_tensorrt.compile(model.backbone, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))])
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 2403 produced from %x.1 : Tensor, %2374 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2375 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2376 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2377 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2378 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2379 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2380 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2381 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2382 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2383 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2384 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2385 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2386 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2387 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2388 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2389 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2390 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2391 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2392 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2393 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2394 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2395 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2396 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2397 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2398 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2399 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2400 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %2401 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %2402 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %2403 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0) = prim::Param() in lowering graph for mini graph input.
I am using the latest official image 22.06: nvcr.io/nvidia/pytorch:22.06-py3.
Unfortunately the code is not open source, I will try and see if I can get a small repro case.
Hi @Belval, it looks like NGC 22.06 does not contain our newest fix #1140 for this bug. Could you share a small repro so I can take a look at what's going on? Thanks.
Still working on the repro, but I just built torch_tensorrt from source (using master) to see if it helped, and it did partially solve the issue above. Unfortunately, I get a different Expected ivalues_maps.count(input) to be true but got false error.
Here is the stacktrace for reference:
Traceback (most recent call last):
File "scripts/evaluation/infer_tables_tensorrt.py", line 44, in <module>
main()
File "scripts/evaluation/infer_tables_tensorrt.py", line 22, in main
torch_tensorrt.compile(model.backbone, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))])
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 111, in compile
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 17979 produced from %17979 : int[] = prim::ListConstruct(%batch.8, %17960) in lowering graph for mini graph input.
Hey @Belval, could you please try either adding this line:
"torch_executed_ops": ["prim::ListConstruct"]
or setting:
"min_block_size": 1
Details about why this happens can be found here: https://github.com/pytorch/TensorRT/issues/1173. Will raise a PR soon to cover these cases.
torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))], torch_executed_ops=["prim::ListConstruct"])
returns a very similar stack trace to the one before:
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got falseCould not find torch::jit::Value* 1968 produced from %x.1 : Tensor, %1939 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %1940 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %1941 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %1942 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1943 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1944 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1945 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1946 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1947 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1948 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1949 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1950 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1951 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1952 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1953 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1954 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1955 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1956 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1957 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1958 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1959 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1960 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1961 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1962 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1963 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1964 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1965 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %1966 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %1967 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %1968 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0) = prim::Param() in lowering graph for mini graph input.
Interestingly, torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))], min_block_size=1)
works to some extent, but instead I get no exception; the compilation exits with an error code (1) without printing an error, the last output being:
%15770 : int[] = prim::ListConstruct(%15715, %15769)
%15771 : int = prim::min(%15770) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
= prim::Loop(%15771, %self.bottom_up.stages_and_names.res2.2.conv2.use_bn.1) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
block0(%1936 : int):
%1938 : Tensor = aten::__getitem__(%1923, %1936) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
%10295 : str = aten::__getitem__(%self._out_features.1, %1936) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
= aten::_set_item(%1932, %10295, %1938) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
-> (%self.bottom_up.stages_and_names.res2.2.conv2.use_bn.1)
return (%1932)
:
Any ideas?
@Belval I was trying to reproduce the error on NGC 22.06. However, I kept getting library loading errors when trying to reproduce your error by building torch_tensorrt from source. Did I miss anything?
I am not sure that I understand your question. If you are referring to the repro package I sent you, then it could be the torch.ops.load_library call in repro.py that does not have the correct path. libcustom_deform_conv.so contains the compiled TorchScript operators.
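For reference, the load call is just the standard pattern (a sketch; the hard-coded path below is hypothetical and needs to point at wherever the compiled libcustom_deform_conv.so actually lives on your machine):
import os
import torch

# Hypothetical path: adjust to wherever libcustom_deform_conv.so was built or extracted.
lib_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "libcustom_deform_conv.so")
torch.ops.load_library(lib_path)  # registers the custom deformable-conv TorchScript operators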