`Reshape` & `Cast` nodes not supported with `onnxruntime==1.15.1` & `onnx==1.14.0`, but work with `onnxruntime==1.16.1`

Open pvardanis opened this issue 1 year ago • 5 comments

I'm using `onnxruntime==1.15.1` and `onnx==1.14.0` with `onnx-graphsurgeon==0.5.2`.

I'm modifying the output of an ONNX graph of an XGBoost model in two ways, using the `Reshape` and `Cast` operators respectively:

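    import numpy as np
    import onnx_graphsurgeon as gs

    # `graph` is the imported gs.Graph and `variable` is one of its existing
    # outputs; both are defined earlier and not shown here.
    variable.name = "old_probabilities"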

    class_1_probs = gs.Variable(
        "class_1_probs",
        dtype=np.float32,
        shape=[None],
    )
    indices = gs.Constant("indices", np.array([1]).astype(np.int64))

    gather_node = gs.Node(
        op="Gather",
        name="gather_node",
        inputs=[variable, indices],
        outputs=[class_1_probs],
    )
    gather_node.attrs["axis"] = 1
    graph.nodes.append(gather_node)

    reshaped_class_1_probs = gs.Variable(
        "probabilities",
        dtype=np.float32,
        shape=[None],
    )
    # Use a distinct tensor name: "indices" is already taken by the Gather constant.
    shape = gs.Constant("shape", np.array([-1]).astype(np.int64))

    reshape_node = gs.Node(
        op="Reshape",
        name="reshape_node",
        inputs=[class_1_probs, shape],
        outputs=[reshaped_class_1_probs],
    )

    graph.nodes.append(reshape_node)
    graph.outputs = [reshaped_class_1_probs]

    # Cast `label` from int64 to float32
    float_predictions = gs.Variable(
        "predictions",
        dtype=np.float32,
        shape=[None],
    )

    cast_node = gs.Node(
        op="Cast",
        name="cast_output",
        inputs=[variable],
        outputs=[float_predictions],
        attrs={"to": 1},  # 1 corresponds to float32
    )

    graph.nodes.append(cast_node)
    graph.outputs = [float_predictions]

The model exports successfully, but fails to load with:

    import onnxruntime as rt

    session = rt.InferenceSession("model.onnx")

and raises the following errors:

    onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Reshape(19) node with name 'reshape_node'
    onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Cast(19) node with name 'cast_output'
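As far as I can tell, the number in parentheses is the opset version at which the runtime looked for a kernel, so `Reshape(19)` would mean "the opset-19 definition of `Reshape`". A minimal sketch for pulling those pieces out of the message text; the helper name `parse_missing_kernel` is hypothetical and the pattern is written only against the wording quoted above:

```python
import re

# Matches the "no kernel" wording quoted in the errors above.
# This is an assumption about the message format, not a documented contract.
_MISSING_KERNEL = re.compile(
    r"Could not find an implementation for (?P<op>\w+)\((?P<opset>\d+)\) "
    r"node with name '(?P<name>[^']*)'"
)


def parse_missing_kernel(message):
    """Return (op_type, opset_version, node_name), or None if nothing matches."""
    match = _MISSING_KERNEL.search(message)
    if match is None:
        return None
    return match.group("op"), int(match.group("opset")), match.group("name")


msg = (
    "[ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an "
    "implementation for Reshape(19) node with name 'reshape_node'"
)
print(parse_missing_kernel(msg))  # ('Reshape', 19, 'reshape_node')
```

Both failing nodes report opset 19, which suggests the exported model declares opset 19 for the default domain.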

The odd thing is that everything works fine with `onnxruntime==1.16.1`. Unfortunately, I'm restricted to 1.15.1 and need a workaround.
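One workaround that might be worth trying, untested against this exact model: lower the opset the model declares for the default domain from 19 to 18 before loading it. This is a sketch under two assumptions: that 1.15.1 ships `Reshape` and `Cast` kernels for opset 18 and below, and that no node in the graph relies on opset-19-only behaviour (for example float8 types or the `saturate` attribute of `Cast`). The helper names `cap_default_opset` and `rewrite_model_file` are hypothetical.

```python
from types import SimpleNamespace

# Both spellings refer to the default ONNX operator set.
DEFAULT_DOMAINS = ("", "ai.onnx")


def cap_default_opset(model, max_version=18):
    """Lower the declared default-domain opset to at most `max_version`.

    Accepts anything with an `opset_import` sequence whose entries carry
    `domain` and `version`, which is the shape of onnx.ModelProto.
    Returns the (old, new) version pairs that were changed.
    """
    changed = []
    for entry in model.opset_import:
        if entry.domain in DEFAULT_DOMAINS and entry.version > max_version:
            changed.append((entry.version, max_version))
            entry.version = max_version
    return changed


def rewrite_model_file(src, dst, max_version=18):
    """Load an ONNX file, cap its default opset, validate it and save it."""
    import onnx  # imported lazily so the helper above has no hard dependency

    model = onnx.load(src)
    cap_default_opset(model, max_version)
    onnx.checker.check_model(model)  # raises if a node is invalid at this opset
    onnx.save(model, dst)


# Stand-in for a ModelProto, only to show the effect without a real file.
fake_model = SimpleNamespace(
    opset_import=[
        SimpleNamespace(domain="", version=19),
        SimpleNamespace(domain="ai.onnx.ml", version=3),
    ]
)
print(cap_default_opset(fake_model))  # [(19, 18)]
```

With a real file this would be `rewrite_model_file("model.onnx", "model_opset18.onnx")`, followed by loading the new file with `rt.InferenceSession`. With onnx-graphsurgeon, setting `graph.opset = 18` before `gs.export_onnx(graph)` may have the same effect, but I'd treat that as something to verify. Only the declared version changes here; the nodes themselves are not converted.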

pvardanis, Oct 02 '24 11:10