🐛 [Bug] TensorRT-RTX BatchNorm constant folding produces NaN
Bug Description
The compiled model produces NaN results when BatchNorm constant folding is enabled in RTX; with BatchNorm constant folding disabled in RTX, it works as expected.
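For reference, inference-time BatchNorm constant folding just precomputes a per-channel scale and shift from the frozen running statistics and bakes them into constants. The sketch below is a generic illustration of that arithmetic (not the actual Torch-TensorRT/RTX pass), included to make clear what the fold is expected to compute:
# Generic illustration of inference-time BatchNorm constant folding
# (not the actual Torch-TensorRT/RTX pass). The frozen statistics are
# folded into a per-channel scale/shift applied to the conv output.
import torch

def fold_batchnorm(bn: torch.nn.BatchNorm2d):
    # scale = gamma / sqrt(running_var + eps), shift = beta - running_mean * scale
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    shift = bn.bias - bn.running_mean * scale
    return scale, shift

def apply_folded_bn(x, scale, shift):
    # Broadcast over (N, C, H, W); matches bn(x) in eval mode.
    return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)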
To Reproduce
Steps to reproduce the behavior:
- Use the lluo/tensorrt_rtx_python_try branch.
- Install TensorRT-RTX:
curl -L https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.0/TensorRT-RTX-1.0.0.21.Linux.x86_64-gnu.cuda-12.9.tar.gz -o TensorRT-RTX-1.0.0.21.Linux.x86_64-gnu.cuda-12.9.tar.gz
tar -xzf TensorRT-RTX-1.0.0.21.Linux.x86_64-gnu.cuda-12.9.tar.gz
rtx_lib_dir=${PWD}/TensorRT-RTX-1.0.0.21/lib
export LD_LIBRARY_PATH=${rtx_lib_dir}:$LD_LIBRARY_PATH
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
# make sure to change this to match your Python version
pip install TensorRT-RTX-1.0.0.21/python/tensorrt_rtx-1.0.0.21-311-none-linux_x86_64.whl
- Build Torch-TensorRT with RTX support and run the test (a quick sanity check for the setup is sketched after these steps):
export PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/nightly/cu129
export PIP_INDEX_URL=https://pypi.org/simple
python setup.py develop --use-rtx
python -m pip uninstall -y tensorrt tensorrt_cu12 tensorrt_cu12_bindings tensorrt_cu12_libs
export FORCE_TENSORRT_RTX=1
python test_batchnorm.txt
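Before running the test script, a quick sanity check can confirm that the RTX bindings are the ones being picked up (the tensorrt_rtx module name is assumed from the wheel filename above):
# Sanity check that the TensorRT-RTX bindings and the RTX-enabled
# torch_tensorrt build import cleanly. The import name "tensorrt_rtx"
# is assumed from the wheel filename above.
import os
os.environ.setdefault("FORCE_TENSORRT_RTX", "1")

import tensorrt_rtx  # assumed module name for the TensorRT-RTX wheel
import torch_tensorrt

print("tensorrt_rtx:", getattr(tensorrt_rtx, "__version__", tensorrt_rtx.__file__))
print("torch_tensorrt:", torch_tensorrt.__version__)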
Test Code:
import torch
import torch_tensorrt
import torchvision.models as models
import sys
import os

dtype = torch.float32


class SimpleNetwork(torch.nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()
        self.conv1 = torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        # input: (64, 64, 64, 64)
        x = self.conv1(x)    # output: (64, 64, 64, 64)
        x = self.bn1(x)      # output: (64, 64, 64, 64)
        x = self.relu(x)     # output: (64, 64, 64, 64)
        x = self.maxpool(x)  # output: (64, 64, 32, 32)
        x = self.conv1(x)    # output: (64, 64, 32, 32)
        x = self.bn1(x)      # output: (64, 64, 32, 32) became nan
        return x


model = SimpleNetwork().eval().to("cuda")
input = torch.randn((64, 64, 64, 64), dtype=dtype).to("cuda")

compile_spec = {
    "device": torch_tensorrt.Device("cuda:0"),
    "enabled_precisions": {dtype},
    "ir": "dynamo",
    "pass_through_build_failures": True,
    # "optimization_level": 1,  # default is 3, highest is 5
    "min_block_size": 1,
    "cache_built_engines": False,
    "reuse_cached_engines": False,
    "use_python_runtime": False,
}

exp_program = torch.export.export(model, (input,), strict=False)

if os.environ.get("FORCE_TENSORRT_RTX", "0") == "1":
    DEBUG_LOGGING_DIR = "./tensorrt_rtx_debug_logs"
else:
    DEBUG_LOGGING_DIR = "./tensorrt_debug_logs"

with torch_tensorrt.dynamo.Debugger(
    "graphs",
    logging_dir=DEBUG_LOGGING_DIR,
    capture_fx_graph_after=["complex_graph_detection"],
    save_engine_profile=True,
    profile_format="trex",
    engine_builder_monitor=True,
):
    trt_mod = torch_tensorrt.dynamo.compile(exp_program, inputs=(input,), **compile_spec)

    pyt_output = model(input)
    trt_output = trt_mod(input)
    abs_diff = torch.abs(pyt_output - trt_output)
    print(f"{abs_diff.max().item()=} {abs_diff.mean().item()=} {trt_output.shape=}")
    assert torch.allclose(pyt_output, trt_output, atol=1e-3, rtol=1e-3)
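To confirm that the tolerance failure really comes from NaNs in the TensorRT output rather than an ordinary accuracy gap, the following lines can be appended to the script above:
# Appended check: distinguish NaN/Inf in the TensorRT output from a plain
# accuracy mismatch (pyt_output and trt_output come from the script above).
print("nan in pyt_output:", torch.isnan(pyt_output).any().item())
print("nan in trt_output:", torch.isnan(trt_output).any().item())
print("inf in trt_output:", torch.isinf(trt_output).any().item())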