keras-onnx
Behaviour of stateful RNN/LSTM/GRU in onnx
Hello everyone,
I am trying to export a Keras model that has a SimpleRNN layer with stateful=True, which ensures that the last state is kept between subsequent samples. However, when I export such a model using keras2onnx and evaluate it with onnxruntime, the behaviour is different: it looks like the state is reset between subsequent samples.
Is there a way to get stateful behaviour in the converted ONNX model?
The ONNX GRU and LSTM ops do not have a stateful attribute, so the Keras layer cannot be mapped directly to an ONNX op, and it is relatively hard for the converter to construct this behaviour. It feels like the best way forward is to open a proposal in the onnx repo asking to add this stateful attribute.
One solution may be to add an input to the model for passing the state in, and an output through which the model returns the state. Then you can manage statefulness in your own code.
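For illustration only (not from the original thread), here is a minimal sketch of what that external state management could look like on the onnxruntime side, assuming the exported model has exactly one data input, one state input and two outputs (data, then state); the file name, the tensor names "input"/"state_in" and the sizes are placeholders, not the converter's actual names:

import numpy as np
import onnxruntime as rt

# Placeholder path to a model exported with an explicit state input/output.
sess = rt.InferenceSession("model_with_explicit_state.onnx")

# Start from a zero state and feed the returned state back in on every call;
# this reproduces what stateful=True does inside Keras, but under your own control.
state = np.zeros((1, 48), dtype=np.float32)
for _ in range(5):
    tokens = np.random.randint(0, 10, size=(1, 1)).astype(np.int32)
    output, state = sess.run(None, {"input": tokens, "state_in": state})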
Hello,
I have converted the sequential model presented in the TensorFlow/Keras text generation tutorial into a Keras functional model, which enables me to pass the RNN state (in my case a GRU cell) as an input parameter and retrieve it as an output parameter. In the end it is working now, but I encountered two issues that caused keras2onnx to fail:
- I had to specify unroll=True in the GRU layer, otherwise I get the following error message: ValueError: Unable to find out a correct type for tensor type = <dtype: 'variant'> of gru_1/while:3
- I had to adjust the list index in line 400 (or line 394 in release 1.7.0, which I am using) in build_output_states() in keras2onnx/ke2onnx/simplernn.py from 1 to 0, otherwise I get an "index out of range" error:

if output_state:
    output_h = operator.outputs[0].full_name
    apply_squeeze(scope, rnn_h, output_h, container)
I'm not an experienced machine learning and Python developer, so I may have done something totally stupid that caused those problems. I'm using Keras 2.4 from TensorFlow 2.3.1, keras2onnx 1.7 and Python 3.6.8.
The source code for reproducing that issue:
import tensorflow as tf
from tensorflow.keras import layers
import keras2onnx

def createModel(input_dim, batch_size, seq_len, embedding_dim=32, rnn_state_size=48, include_RNN_state=False):
    inputs = tf.keras.Input(shape=(seq_len,), dtype=tf.int32, name="Input", batch_size=batch_size)
    rnn_inputs = None
    if include_RNN_state:
        # Extra input so the caller can pass the previous GRU state in explicitly.
        rnn_inputs = tf.keras.Input(shape=rnn_state_size, dtype=tf.float32, name="Input_RNN", batch_size=batch_size)
    embedding = layers.Embedding(input_dim=input_dim, output_dim=embedding_dim, batch_size=batch_size)
    x0 = embedding(inputs)
    rnn_outputs = None
    x1 = None
    if include_RNN_state:
        # unroll=True is required here, otherwise the conversion fails (see above).
        rnn = layers.GRU(rnn_state_size, batch_size=batch_size, stateful=True, return_state=True,
                         return_sequences=True, unroll=True, recurrent_initializer='glorot_uniform')
        x1, gru_state = rnn(x0, initial_state=rnn_inputs)
        rnn_outputs = gru_state
    else:
        rnn = layers.GRU(rnn_state_size, batch_size=batch_size, stateful=True, return_sequences=True,
                         recurrent_initializer='glorot_uniform')
        x1 = rnn(x0)
    dense = layers.Dense(input_dim)
    outputs = dense(x1)
    model = None
    if include_RNN_state:
        # Expose the GRU state as an additional model input and output.
        model = tf.keras.Model(inputs=[inputs, rnn_inputs], outputs=[outputs, rnn_outputs], name="char_predict_rnn")
    else:
        model = tf.keras.Model(inputs=inputs, outputs=outputs, name="char_predict_rnn")
    return model

model_without_rnnstate = createModel(10, 1, 1)
onnx_model = keras2onnx.convert_keras(model_without_rnnstate, model_without_rnnstate.name, debug_mode=True)

model_with_rnnstate = createModel(10, 1, 1, include_RNN_state=True)
onnx_model_stateful = keras2onnx.convert_keras(model_with_rnnstate, model_with_rnnstate.name, debug_mode=True)
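As a follow-up sketch (not part of the original report), the converted stateful model could then be saved and run with onnxruntime, feeding the returned GRU state back in as Input_RNN on the next call; the output order below is an assumption and should be checked via sess.get_outputs():

import numpy as np
import onnxruntime as rt

keras2onnx.save_model(onnx_model_stateful, "char_predict_rnn.onnx")
sess = rt.InferenceSession("char_predict_rnn.onnx")
print([i.name for i in sess.get_inputs()])    # expected: "Input" and "Input_RNN"
print([o.name for o in sess.get_outputs()])   # logits and GRU state

# One step: zero initial state, then reuse the returned state for the next call.
state = np.zeros((1, 48), dtype=np.float32)
tokens = np.zeros((1, 1), dtype=np.int32)
logits, state = sess.run(None, {"Input": tokens, "Input_RNN": state})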