
[BUG] Serving a TensorRT model with CUDA graph results in weird inconsistent outputs.

Open WingEdge777 opened this issue 1 month ago • 3 comments

Description: When serving a TensorRT engine with CUDA graph optimization enabled, we encountered a strange phenomenon.

We send requests sequentially, following an AAAAABBBBBAAAABBBB pattern. At the start of each A (or B) period, the first few A (or B) requests often return the previous period's B (or A) results, which is completely wrong for the current input.

[screenshot: outputs with CUDA graph enabled, showing stale results from the previous round]

However, if we remove the CUDA graph optimization config, all results become consistent and correct with respect to A (or B).

[screenshot: outputs with CUDA graph disabled, all consistent and correct]

Thus, we strongly suspect that Triton, the TensorRT backend, or TensorRT itself is reusing wrong/dirty/uninitialized input buffers.

CUDA graph is an essential optimization for many use cases. This issue prevents us from upgrading CUDA/TensorRT/Triton Server.

Triton Information: I am using the NVIDIA NGC Triton Server 25.10 container, which ships Triton Server 2.62.0 and TensorRT 10.13 according to the release notes. I'm running the server on L20/A10 GPUs with NVIDIA driver 535.161.08, using the NVIDIA compat lib.
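
For reference, the driver version visible inside the container and the TensorRT version can be confirmed with commands along these lines (assuming the TensorRT Python bindings are present in the NGC image):

# driver version as seen from inside the container
nvidia-smi --query-gpu=driver_version --format=csv,noheader
# TensorRT version
python3 -c "import tensorrt; print(tensorrt.__version__)"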

To Reproduce: The reproduction steps are simple. Export a ResNet-50 ONNX model, build a TensorRT engine, serve it with CUDA graph optimization enabled, and then send requests to the server.

Refer to the scripts below:

1. Export ONNX:

import torch
import torchvision.models as models

# ResNet-50 with random weights; pretrained weights are not needed to reproduce the issue.
model = models.resnet50(weights=None)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
output_onnx_file = "resnet50.onnx"

# Export with a dynamic batch dimension so the TensorRT engine can be built
# with min/opt/max batch shapes in the next step.
torch.onnx.export(
    model,
    dummy_input,
    output_onnx_file,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    dynamo=False,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

2. Build the TensorRT engine:

polygraphy convert ./resnet50.onnx --convert-to trt --output ./resnet50.plan \
    --fp16 \
    --trt-min-shapes input:[1,3,224,224]  \
    --trt-opt-shapes input:[4,3,224,224]   \
    --trt-max-shapes input:[8,3,224,224]

3. Set up the model repository (model_zoo) and the model config config.pbtxt:

name: "resnet_50"
backend: "tensorrt"
max_batch_size: 8

model_warmup: {
    name: "sample"
    batch_size: 1
    inputs: {
        key: "input"
        value: {
            data_type: TYPE_FP32,
            dims: [3, 224, 224],
            zero_data: true
        }
    }
}

optimization {
    graph: { level: 1 },
    eager_batching: 1,
    cuda: {
        graphs: 1,
        graph_spec: [
            { batch_size: 1 },
            { batch_size: 2 },
            { batch_size: 3 },
            { batch_size: 4 },
            { batch_size: 5 },
            { batch_size: 6 },
            { batch_size: 7 },
            { batch_size: 8 }
        ]
        busy_wait_events: 1,
        output_copy_stream: 1
    }
}

dynamic_batching {
    preferred_batch_size: [4, 8]
    max_queue_delay_microseconds: 2000
}

instance_group [ { count: 2, kind: KIND_GPU, gpus: [0] } ]
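
For completeness, the directory layout assumed above follows the standard Triton model repository convention: the engine from step 2 is renamed to model.plan and placed under a numeric version directory next to config.pbtxt.

model_zoo/
└── resnet_50/
    ├── config.pbtxt
    └── 1/
        └── model.plan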

4. Start the server:

tritonserver --strict-model-config=0 --metrics-port=8102 --http-port=8100 --grpc-port=8101 --model-repository=./model_zoo --log-verbose=0 --backend-config=python,shm-default-byte-size=335544320

5. Client script sending requests:

import argparse
import numpy as np
import sys

import tritonclient.http as httpclient

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8100',
                        help='Inference server URL. Default is localhost:8100.')

    FLAGS = parser.parse_args()
    request_count = 50
    try:
        triton_client = httpclient.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose, concurrency=request_count)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit()

    ################################################### img check
    model_name = "resnet_50"

    output_name = ["output"]
    np.random.seed(1024)
    input_data = np.random.randn(2, 3, 224, 224).astype(np.float32)
    # input0_data = load_image("bag.jpeg")
    print(input_data.shape)

    for i in range(5):
        batch_id = i % 2
        # Infer
        inputs = []
        outputs = []
        
        inputs.append(httpclient.InferInput('input', input_data[batch_id:batch_id+1].shape, "FP32"))


        # Set the input data for this round: slice A (batch_id == 0) or B (batch_id == 1).
        inputs[0].set_data_from_numpy(input_data[batch_id:batch_id+1])

        for name in output_name:
            outputs.append(httpclient.InferRequestedOutput(name))

        headers = {}
        cnt = 7  # requests per round
        print(f"round {i}")
        for j in range(cnt):
            async_request = triton_client.async_infer(model_name=model_name,
                                        inputs=inputs,
                                        outputs=outputs, headers=headers)
            result = async_request.get_result()
            for name in output_name:
                out = result.as_numpy(name)
                print(f"input case : {batch_id}", ", output: ", name, out.shape, out[0][0])
                break

Expected behavior: The server should return consistent and correct results.

WingEdge777 commented on Nov 28 '25

We also tested with Triton Server 25.06 (based on CUDA 12.9). The anomalous results persisted, which suggests the issue is likely unrelated to NVIDIA CUDA driver compatibility. In principle, CUDA 12.x should work correctly with driver 535.

[screenshots: the same inconsistent results reproduced on Triton Server 25.06]

WingEdge777 commented on Nov 28 '25

The cause has been narrowed down: disabling output_copy_stream works around the issue. It appears that the output copy stream does not correctly synchronize with the CUDA graph execution stream.
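
For anyone else hitting this, a minimal sketch of the workaround based on the config.pbtxt from step 3: keep the CUDA graph settings but leave output_copy_stream at its default (disabled). Only the optimization block changes:

optimization {
    graph: { level: 1 },
    eager_batching: 1,
    cuda: {
        graphs: 1,
        graph_spec: [
            { batch_size: 1 },
            { batch_size: 2 },
            { batch_size: 3 },
            { batch_size: 4 },
            { batch_size: 5 },
            { batch_size: 6 },
            { batch_size: 7 },
            { batch_size: 8 }
        ]
        busy_wait_events: 1
        # output_copy_stream removed (defaults to disabled); with it enabled we see the stale outputs
    }
}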

We look forward to a fix in a future release.

WingEdge777 commented on Dec 02 '25

Thank you for reporting this and for such an in-depth investigation.

whoisj commented on Dec 02 '25