
The onnx parser failed to parse a valid model: Slice (importSlice): INVALID_NODE: Assertion failed: (starts.size() == axes.size())

Open coffezhou opened this issue 7 months ago • 2 comments

Description

For the following valid ONNX model,

[Image: graph of the ONNX model]

the ONNX frontend in TensorRT fails to import it and produces the following error message:

[05/29/2025-10:49:12] [TRT] [E] In node 5 with name:  and operator: Slice (importSlice): INVALID_NODE: Assertion failed: (starts.size() == axes.size()): The shape of input starts misaligns with the shape of input axes. Shape of input starts = 1, shape of input axes = 2.
In node 5 with name:  and operator: Slice (importSlice): INVALID_NODE: Assertion failed: (starts.size() == axes.size()): The shape of input starts misaligns with the shape of input axes. Shape of input starts = 1, shape of input axes = 2.

However, this model can be executed by onnxruntime. The output is as follows:

ONNXRuntime:
 [array([[1., 1., 1.]], dtype=float32)]

Environment

TensorRT Version: 10.11.0.33

NVIDIA GPU: GeForce RTX 3080

NVIDIA Driver Version: 535.183.01

CUDA Version: 12.2

CUDNN Version: none

Operating System: Ubuntu 20.04

Python Version (if applicable): 3.12.9

Tensorflow Version (if applicable): none

PyTorch Version (if applicable): none

Baremetal or Container (if so, version): none

Steps To Reproduce

This bug can be reproduced with the following script and the model in the attachment. As the script shows, the model can be executed by onnxruntime.

import pickle
import sys

import onnx
import onnxruntime
import tensorrt as trt


def test():
    onnx_model = onnx.load("111.onnx")

    # Input feed (dict of input name -> numpy array) shipped with the test case.
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    # Reference run on ONNX Runtime (CPU).
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run(None, inputs)  # None -> return all outputs
    except Exception as e:
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)

    print("ONNXRuntime:\n", ort_output)

    # --------------------------------------------------------
    # Parse the same model with the TensorRT ONNX parser.
    trt_logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(trt_logger, "")
    builder = trt.Builder(trt_logger)
    # TensorRT 10 networks are always explicit batch; the EXPLICIT_BATCH flag is deprecated and ignored.
    network = builder.create_network(0)

    parser = trt.OnnxParser(network, trt_logger)
    with open("111.onnx", "rb") as model_file:
        if not parser.parse(model_file.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            sys.exit(1)
    print("TensorRT parsed the model successfully.")


if __name__ == "__main__":
    test()

testcase.zip
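To check how the Slice node in the attached model is wired, in particular whether its optional axes input is supplied, a small helper like the one below can list every Slice node. This is a sketch: describe_slice_nodes and main are hypothetical names, the helper relies only on the standard onnx.load result and the NodeProto fields op_type, name and input, and the reported index is the node's position in graph.node, which may not match the node number in TensorRT's error message.

```python
def describe_slice_nodes(model):
    """List every Slice node in an ONNX ModelProto-like object.

    Returns (index, name, num_inputs, axes_given) tuples. Slice inputs are
    ordered data, starts, ends, axes, steps; trailing optional inputs may be
    dropped entirely, and an empty string marks an omitted optional input.
    """
    report = []
    for index, node in enumerate(model.graph.node):
        if node.op_type != "Slice":
            continue
        inputs = list(node.input)
        # axes is input #3; it counts as given only if present and non-empty.
        axes_given = len(inputs) > 3 and inputs[3] != ""
        report.append((index, node.name, len(inputs), axes_given))
    return report


def main(path="111.onnx"):
    # Imported lazily so describe_slice_nodes() works without onnx installed.
    import onnx

    model = onnx.load(path)
    for index, name, num_inputs, axes_given in describe_slice_nodes(model):
        state = "given" if axes_given else "omitted"
        print(f"node {index} (name={name!r}): {num_inputs} inputs, axes {state}")
```

Calling main() in the directory that contains 111.onnx prints one line per Slice node.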

Commands or scripts:

Have you tried the latest release?: yes

Can this model run on other frameworks? For example, run the ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): yes, the model can be executed by onnxruntime.

coffezhou, May 29 '25 02:05


As the error suggests, the starts and axes inputs of this Slice node have different shapes. See also the ONNX spec for Slice. Other frameworks may produce different results where the behavior is undefined. Purely as a guess for your case: since you've provided two values for the axes input, you probably want starts[i] to be the same for both axes. If you want a different behavior, you may need to adjust the model.
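As a sketch of the length check the error refers to, the hypothetical helper below validates explicit Slice inputs. It assumes starts, ends, axes and the optional steps each carry one entry per sliced axis, normalizes negative axes using the [-r, r-1] range from the ONNX Slice spec, and rejects repeated axes, whose behavior the spec leaves undefined. It does not model what happens when axes is omitted, and it is not TensorRT's actual implementation.

```python
def check_slice_inputs(rank, starts, ends, axes, steps=None):
    """Validate explicit Slice inputs and return the normalized axes.

    Hypothetical helper: starts[i], ends[i] and steps[i] are paired with
    axes[i], so all four lists are expected to have the same length.
    """
    if steps is None:
        steps = [1] * len(starts)  # omitted steps default to 1 per sliced axis
    if not (len(starts) == len(ends) == len(axes) == len(steps)):
        raise ValueError(
            f"length mismatch: starts={len(starts)}, ends={len(ends)}, "
            f"axes={len(axes)}, steps={len(steps)}"
        )
    normalized = []
    for a in axes:
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
        normalized.append(a + rank if a < 0 else a)  # negative axes count from the back
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis: behavior is undefined per the ONNX spec")
    return normalized
```

For example, check_slice_inputs(2, [0], [1], [0, 1]) fails with a message naming starts=1 and axes=2, the same mismatch reported by TensorRT.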

cc @kevinch-nv

poweiw, Jun 04 '25 00:06

I think the bug is here: https://github.com/onnx/onnx-tensorrt/blob/745bde22c2fe883968cf18cc9ebdfb2e2985166d/onnxOpImporters.cpp#L5652

        // "If axes are omitted, they are set to [0, ..., ndim-1]."
        axes = nbInputs > 3 && !inputs.at(3).isNullTensor() ? ShapeTensor(ctx, inputs.at(3))
                                                            : iotaShapeVector(dims.size());

Instead of defaulting axes to a length of dims.size(), it should default to a length of starts.size(), i.e. iotaShapeVector(starts.size()).
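The difference between the two defaults can be sketched in plain Python. The helper names are hypothetical and the code only mirrors the logic of the quoted importer lines, not TensorRT's real implementation. The example assumes a rank-2 data input with a one-element starts, which would be consistent with the sizes in the error message (starts = 1, axes = 2).

```python
def default_axes(rank, starts, rule):
    """Axes used when the optional Slice `axes` input is omitted.

    rule="rank":   [0, ..., rank-1], like iotaShapeVector(dims.size())
    rule="starts": [0, ..., len(starts)-1], like iotaShapeVector(starts.size())
    """
    if rule == "rank":
        return list(range(rank))
    if rule == "starts":
        return list(range(len(starts)))
    raise ValueError(f"unknown rule: {rule!r}")


def import_slice(rank, starts, ends, axes=None, rule="rank"):
    """Return (axis, start, end) triples, or raise if starts and axes disagree."""
    if axes is None:
        axes = default_axes(rank, starts, rule)
    # Mirrors the importer assertion starts.size() == axes.size().
    if len(starts) != len(axes):
        raise ValueError(
            f"Shape of input starts = {len(starts)}, shape of input axes = {len(axes)}"
        )
    return list(zip(axes, starts, ends))


# Rank-2 input, one-element starts/ends, axes omitted:
# import_slice(2, [0], [1], rule="rank")    raises: starts = 1, axes = 2
# import_slice(2, [0], [1], rule="starts")  returns [(0, 0, 1)]
```

In this sketch, the rank-based default trips the check for any Slice that omits axes and slices fewer axes than the input rank, while the starts-based default slices only the leading len(starts) axes and leaves the rest untouched.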

pranavm-nvidia, Jun 06 '25 17:06