onnxruntime_backend
Not able to load simple iris model: Getting error: `Unsupported ONNX Type 'ONNX_TYPE_SEQUENCE'`
Description
Getting an error "failed to load 'model_onnx' version 1: Unsupported: Unsupported ONNX Type 'ONNX_TYPE_SEQUENCE' for I/O 'output_probability', expected 'ONNX_TYPE_TENSOR'"
Triton Information nvcr.io/nvidia/tritonserver:21.10-py3
Are you using the Triton container or did you build it yourself? Triton container
To Reproduce Steps to reproduce the behavior.
- Train an iris model or use model: https://github.com/guyroyse/redisai-iris/blob/main/iris.onnx
# Train a model.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)
# Convert into ONNX format
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
# Compute the prediction with ONNX Runtime
import onnxruntime as rt
import numpy
sess = rt.InferenceSession("rf_iris.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
print(input_name)
print(label_name)
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
- Try to run the iris model on triton with onnx as backend.
tritonserver --strict-model-config=false --model-repository=/models
{
"error": "load failed for model 'model_onnx': version 1: Unsupported: Unsupported ONNX Type 'ONNX_TYPE_SEQUENCE' for I/O 'output_probability', expected 'ONNX_TYPE_TENSOR'.;\n"
}
Describe the models (framework, inputs, outputs), ideally include the model configuration file (if using an ensemble include the model configuration file for that as well).
- model.onnx
- No config.pbtxt is provided, since it should be auto-generated for ONNX models.
Expected behavior model.onnx load should not fail.
Thank you for reporting this potential bug. We will investigate.
Hi @dyastremsky @CoderHam, is this under investigation? I wanted to confirm whether this could be caused by the ONNX version used to convert the model. Also, is there any workaround for this?
@KshitizLohia Yes, it's in our queue to investigate. This error could definitely be due to the ONNX version. What version did you use for conversion? The quickest solution would be to try using the most recent version of Triton server and checking whether your version of ONNX is supported by the current ONNX runtime.
@askhade Can you comment on this issue?
onnxruntime is able to load this model and my test with random test data was also successful. This error is coming from ort backend in triton: https://github.com/triton-inference-server/onnxruntime_backend/blob/main/src/onnxruntime_utils.cc#L161
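For anyone who wants to check up front which inputs or outputs the backend will reject, a rough sketch: ONNX Runtime exposes each I/O's type as a string, so anything that does not start with `tensor(` is a candidate for this error. The helper name `find_non_tensor_ios` is hypothetical, and the stand-in session below only mimics the `get_inputs()` / `get_outputs()` interface; the type string shown for `output_probability` is an assumption about how this model is reported.

```python
from types import SimpleNamespace


def find_non_tensor_ios(session):
    """Return (name, type) for every input/output whose type is not a plain tensor."""
    ios = list(session.get_inputs()) + list(session.get_outputs())
    return [(io.name, io.type) for io in ios if not io.type.startswith("tensor(")]


# Stand-in for onnxruntime.InferenceSession("rf_iris.onnx"), so the sketch runs
# without a model file. The type strings are assumed, not captured from a real run.
fake_session = SimpleNamespace(
    get_inputs=lambda: [SimpleNamespace(name="float_input", type="tensor(float)")],
    get_outputs=lambda: [
        SimpleNamespace(name="output_label", type="tensor(int64)"),
        SimpleNamespace(
            name="output_probability", type="seq(map(int64,tensor(float)))"
        ),
    ],
)

offending = find_non_tensor_ios(fake_session)
for name, onnx_type in offending:
    print(f"{name}: {onnx_type}")
```

With a real model, passing `onnxruntime.InferenceSession("rf_iris.onnx")` instead of the stand-in should list the same kind of entries.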
@deadeyegoodwin can triton support sequence<map<>> ?
Yes, with onnxruntime we are able to load this model. But Triton is not able to load it with an auto-generated configuration.
@askhade I'm not familiar with what ONNX_TYPE_SEQUENCE is, but triton itself only supports tensors-in / tensors-out. It seems like one problem here is that the automatic configuration generator for onnxruntime backend should have failed for the output if it doesn't know what to do with ONNX_TYPE_SEQUENCE. Can the backend be updated to map ONNX_TYPE_SEQUENCE in/out to tensors?
Sequence in onnx is essentially a list. It can be a list of tensors, a list of list of tensors, list of map of tensors etc... In this case we have sequence<map<>>. Need to do some investigation to answer your question. Will update in a few.
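To make the sequence<map<>> shape concrete: when this model is run directly with onnxruntime from Python, the probability output comes back as a list with one dict per input row, mapping class label to probability. A minimal sketch of that structure and of flattening it into the dense 2-D array a tensor-only server would expect (the helper name `probabilities_to_array` and the sample values are made up for illustration):

```python
import numpy as np


def probabilities_to_array(seq_of_maps, class_labels):
    """Flatten a list of {label: probability} dicts into a (rows, classes) float32 array."""
    return np.array(
        [[row[label] for label in class_labels] for row in seq_of_maps],
        dtype=np.float32,
    )


# Illustrative values only: two rows, three iris classes.
zipmap_output = [
    {0: 0.9, 1: 0.1, 2: 0.0},
    {0: 0.2, 1: 0.3, 2: 0.5},
]

dense = probabilities_to_array(zipmap_output, class_labels=[0, 1, 2])
print(dense.shape)
```

This conversion happens on the client side, so it does not help Triton load the model; it only shows what the non-tensor output contains.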
Adding some information to the issue: I see the same error with the latest Triton Docker image (21.12-py3). I get it for both the RandomForestClassifier and a simple LogisticRegression:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from sklearn.linear_model import LogisticRegression
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LogisticRegression(max_iter=500)
clr.fit(X_train, y_train)
print(clr)
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("lr_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
These are the versions of the packages used to produce the ONNX file:
numpy: 1.21.5
scikit-learn: 0.23.2
onnx: 1.10.2
onnxruntime: 1.10.0
skl2onnx: 1.10.4
Hi, I am also encountering the same issue. I tried loading a model created with skl2onnx both implicitly and explicitly with a config file. Can someone here help, please? Is there a fix already? A workaround would also be very helpful. Thanks @askhade @deadeyegoodwin @dyastremsky
I have a similar issue, but with ONNX_TYPE_MAP, using the DictionaryType converter from scikit-learn's DictVectorizer. Is there a plan to support this?
Apologies for not responding sooner. This is in our queue. Right now, only tensors are supported, but we are looking into the fix and feature request.
@KshitizLohia have you tried setting the option zipmap=False, as shown here? http://onnx.ai/sklearn-onnx/auto_tutorial/plot_dbegin_options_zipmap.html#option-zipmap-false
This should give the model an array-like (tensor) output, which should be Triton compatible.