TensorRT
`Reshape` & `Cast` nodes not supported with `onnxruntime==1.15.1` & `onnx==1.14.0`, but work with `onnxruntime==1.16.1`
I'm using `onnxruntime==1.15.1` and `onnx==1.14.0` with `onnx-graphsurgeon==0.5.2`.
I'm modifying the outputs of the ONNX graph of an XGBoost model, using Gather + Reshape on the `probabilities` output and Cast on the `label` output, respectively, as follows:
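```python
import numpy as np
import onnx_graphsurgeon as gs

# Assumes `graph` is the gs.Graph of the XGBoost model (e.g. from gs.import_onnx)
# and `variable` is the graph output being replaced.

# 1) `probabilities` output: keep only the class-1 column
variable.name = "old_probabilities"
class_1_probs = gs.Variable(
    "class_1_probs",
    dtype=np.float32,
    shape=[None, 1],  # Gather with 1-D indices keeps axis 1 -> [N, 1]
)
indices = gs.Constant("indices", np.array([1]).astype(np.int64))
gather_node = gs.Node(
    op="Gather",
    name="gather_node",
    inputs=[variable, indices],
    outputs=[class_1_probs],
)
gather_node.attrs["axis"] = 1
graph.nodes.append(gather_node)

# Flatten [N, 1] -> [N]
reshaped_class_1_probs = gs.Variable(
    "probabilities",
    dtype=np.float32,
    shape=[None],
)
# Needs its own name: two distinct tensors must not share the name "indices"
shape = gs.Constant("reshape_shape", np.array([-1]).astype(np.int64))
reshape_node = gs.Node(
    op="Reshape",
    name="reshape_node",
    inputs=[class_1_probs, shape],
    outputs=[reshaped_class_1_probs],
)
graph.nodes.append(reshape_node)
graph.outputs = [reshaped_class_1_probs]

# 2) `label` output (here `variable` is the int64 `label` tensor):
# cast `label` from int64 to float32
float_predictions = gs.Variable(
    "predictions",
    dtype=np.float32,
    shape=[None],
)
cast_node = gs.Node(
    op="Cast",
    name="cast_output",
    inputs=[variable],
    outputs=[float_predictions],
    attrs={"to": 1},  # 1 == onnx.TensorProto.FLOAT
)
graph.nodes.append(cast_node)
# This replaces the outputs set in snippet 1; if both snippets run in one
# script, assign both new outputs in a single graph.outputs assignment.
graph.outputs = [float_predictions]
```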
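To make the shapes concrete, here is a plain-Python sketch (no ONNX involved; `gather_axis1` and `reshape_flat` are hypothetical helpers) of what the Gather (`axis=1`) + Reshape (`[-1]`) pair does to an `[N, 2]` probabilities tensor. Per the ONNX spec, Gather's output rank is `rank(data) + rank(indices) - 1`, so the 1-D `indices` constant keeps axis 1 and yields `[N, 1]`, which the Reshape then flattens to `[N]`.

```python
from typing import List, Sequence, Union

Row = List[float]


def gather_axis1(data: List[Row], indices: Union[int, Sequence[int]]):
    """Mimic ONNX Gather(axis=1) on a rank-2 input given as nested lists.

    Output rank is rank(data) + rank(indices) - 1: a scalar index drops
    axis 1, a 1-D list of indices keeps it.
    """
    if isinstance(indices, int):
        return [row[indices] for row in data]  # shape [N]
    return [[row[i] for i in indices] for row in data]  # shape [N, len(indices)]


def reshape_flat(data: List[Row]) -> Row:
    """Mimic Reshape with shape=[-1] on a rank-2 input: flatten to rank 1."""
    return [x for row in data for x in row]


probs = [[0.9, 0.1], [0.3, 0.7]]  # toy [N=2, 2] class probabilities
print(gather_axis1(probs, [1]))  # [[0.1], [0.7]]  -> shape [2, 1]
print(reshape_flat(gather_axis1(probs, [1])))  # [0.1, 0.7]  -> shape [2]
print(gather_axis1(probs, 1))  # [0.1, 0.7]  -> shape [2]
```

A scalar (0-d) index such as `np.array(1, dtype=np.int64)` would give `[N]` directly, so in principle the Reshape node could be dropped; the Cast node would still be affected, though.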
The model exports successfully, but fails to load with:
```python
import onnxruntime as rt

session = rt.InferenceSession("model.onnx")
```
and raises the following errors:
```
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Reshape(19) node with name 'reshape_node'
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Cast(19) node with name 'cast_output'
```
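The `(19)` in both messages is the operator version each node resolved to. As I understand ONNX versioning, a node uses the highest version of its operator that is not newer than the model's default-domain opset. The sketch below models that rule; the since-versions in `SINCE_VERSIONS` are assumed from the ONNX operator changelog as of onnx 1.14, not read from the model. It suggests why only Reshape and Cast fail: at opset 19, Gather still resolves to Gather-13, while Reshape and Cast resolve to their opset-19 versions, for which 1.15.1 apparently registers no kernel.

```python
# Since-versions of the default-domain ops used above
# (assumed from the ONNX operator changelog as of onnx 1.14).
SINCE_VERSIONS = {
    "Gather": [1, 11, 13],
    "Reshape": [1, 5, 13, 14, 19],
    "Cast": [1, 6, 9, 13, 19],
}


def resolve(op: str, model_opset: int) -> int:
    """Return the operator version a node resolves to: the highest
    since-version that is <= the model's default-domain opset."""
    candidates = [v for v in SINCE_VERSIONS[op] if v <= model_opset]
    if not candidates:
        raise ValueError(f"{op} is not defined at opset {model_opset}")
    return max(candidates)


for op in SINCE_VERSIONS:
    # prints: Gather 13 13 / Reshape 19 14 / Cast 19 13
    print(op, resolve(op, 19), resolve(op, 18))
```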
Oddly, everything works fine with `onnxruntime==1.16.1`. Unfortunately, I'm restricted to 1.15.1 and need to find a workaround for this.