keras-onnx

Behaviour of stateful RNN/LSTM/GRU in onnx

Open · srikanthderebail opened this issue 4 years ago · 3 comments

Hello everyone,

I am trying to export a Keras model that has a SimpleRNN layer with stateful=True, so that the last state of each sample in one batch is used as the initial state of the corresponding sample in the next batch. However, when I export such a model with keras2onnx and evaluate it with onnxruntime, the behaviour is different: it looks like the state is reset between subsequent samples.
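To make the difference concrete, here is a minimal NumPy sketch (not Keras or onnxruntime code) of a SimpleRNN-style recurrence h_t = tanh(x_t W + h_{t-1} U + b). The helper `simple_rnn` and the random weights are hypothetical stand-ins. It shows that carrying the final state into the next call reproduces a single pass over the whole sequence, while restarting from a zero state, which is what the converted model appears to do, gives different outputs.

```python
import numpy as np

def simple_rnn(x, h0, W, U, b):
    """SimpleRNN-style recurrence h_t = tanh(x_t W + h_{t-1} U + b).

    x has shape (time, features); returns per-step outputs (time, units)
    and the final state (units,).
    """
    h = h0
    ys = []
    for x_t in x:
        h = np.tanh(x_t @ W + h @ U + b)
        ys.append(h)
    return np.stack(ys), h

rng = np.random.default_rng(0)
features, units = 3, 4
W = rng.normal(size=(features, units))
U = rng.normal(size=(units, units))
b = rng.normal(size=units)
x = rng.normal(size=(6, features))  # one sequence of 6 time steps
zeros = np.zeros(units)

# Reference: the whole sequence in a single call
y_full, _ = simple_rnn(x, zeros, W, U, b)

# "Stateful": split into two chunks and carry the final state across calls
y_a, h_a = simple_rnn(x[:3], zeros, W, U, b)
y_b, _ = simple_rnn(x[3:], h_a, W, U, b)
y_stateful = np.concatenate([y_a, y_b])

# "Reset": the second chunk starts again from a zero state
y_b_reset, _ = simple_rnn(x[3:], zeros, W, U, b)
y_reset = np.concatenate([y_a, y_b_reset])

print(np.allclose(y_stateful, y_full))  # True
print(np.allclose(y_reset, y_full))     # False
```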

Is there a way to get stateful behaviour in the converted ONNX model?

srikanthderebail, Sep 24 '20 14:09

The ONNX GRU and LSTM ops do not have a stateful attribute, so a stateful Keras layer cannot be mapped directly to an ONNX op, and it is relatively hard for the converter to construct equivalent behaviour. I feel the best way forward is to open a proposal in the onnx repo asking for this stateful attribute to be added.
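Although there is no stateful attribute, the ONNX GRU op does take an optional initial_h input and returns the last hidden state as its Y_h output; when initial_h is omitted, the spec treats it as zeros, which matches the "reset" behaviour described above. The following NumPy reference of the forward GRU equations is only a sketch under assumptions: a single direction, the default layout=0 and linear_before_reset=0, and no sequence_lens. The function name `onnx_gru_forward` is hypothetical. Note that Keras' default GRU (reset_after=True) corresponds to linear_before_reset=1, so it is not bit-for-bit what a converted Keras GRU computes.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def onnx_gru_forward(X, W, R, B=None, initial_h=None):
    """NumPy sketch of a forward ONNX GRU (layout=0, linear_before_reset=0).

    X: [seq_length, batch_size, input_size]
    W: [1, 3*hidden_size, input_size]   gate order z, r, h
    R: [1, 3*hidden_size, hidden_size]
    B: [1, 6*hidden_size]               [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]
    initial_h: [1, batch_size, hidden_size]; zeros when omitted
    Returns Y: [seq_length, 1, batch_size, hidden_size], Y_h: [1, batch_size, hidden_size]
    """
    _, batch_size, _ = X.shape
    hidden_size = R.shape[-1]
    Wz, Wr, Wh = np.split(W[0], 3)
    Rz, Rr, Rh = np.split(R[0], 3)
    if B is None:
        B = np.zeros((1, 6 * hidden_size))
    Wbz, Wbr, Wbh, Rbz, Rbr, Rbh = np.split(B[0], 6)
    # Missing initial_h means a zero state, i.e. every call starts "reset"
    H = np.zeros((batch_size, hidden_size)) if initial_h is None else initial_h[0]
    Y = []
    for Xt in X:
        z = sigmoid(Xt @ Wz.T + H @ Rz.T + Wbz + Rbz)
        r = sigmoid(Xt @ Wr.T + H @ Rr.T + Wbr + Rbr)
        h = np.tanh(Xt @ Wh.T + (r * H) @ Rh.T + Rbh + Wbh)
        H = (1 - z) * h + z * H
        Y.append(H)
    Y = np.stack(Y)[:, None]  # insert the num_directions axis
    return Y, H[None]

rng = np.random.default_rng(1)
seq_length, batch_size, input_size, hidden_size = 5, 2, 3, 4
X = rng.normal(size=(seq_length, batch_size, input_size))
W = rng.normal(size=(1, 3 * hidden_size, input_size))
R = rng.normal(size=(1, 3 * hidden_size, hidden_size))
B = rng.normal(size=(1, 6 * hidden_size))

Y, Y_h = onnx_gru_forward(X, W, R, B)
print(Y.shape, Y_h.shape)  # (5, 1, 2, 4) (1, 2, 4)

# Omitting initial_h is the same as passing an explicit zero state
Y0, _ = onnx_gru_forward(X, W, R, B, initial_h=np.zeros((1, batch_size, hidden_size)))
print(np.allclose(Y, Y0))  # True
```

Because Y_h has exactly the shape initial_h expects, an outer program can feed it back on the next call to emulate stateful=True.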

jiafatom, Sep 28 '20 15:09

One solution may be to add an input to the model through which the state is passed in, and an output through which the model returns its final state. You can then manage statefulness in your own code by feeding each returned state back in on the next call.
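A minimal sketch of managing the state in your own code, assuming the exported model takes (input, state) and returns (output, new_state). The class `StatefulRunner` and the `toy_step` function are hypothetical; `toy_step` only stands in for a real model so the example runs on its own. With onnxruntime, the step function could wrap `InferenceSession.run`, e.g. `lambda x, h: sess.run(None, {"Input": x, "Input_RNN": h})`, assuming the exported input names and output order match the Keras model.

```python
from typing import Any, Callable, Tuple

import numpy as np

# A step function maps (model_input, state) -> (model_output, new_state).
StepFn = Callable[[np.ndarray, np.ndarray], Tuple[Any, np.ndarray]]

class StatefulRunner:
    """Carries the recurrent state across calls, similar to Keras stateful=True."""

    def __init__(self, step_fn: StepFn, state_shape: Tuple[int, ...]):
        self.step_fn = step_fn
        self.state_shape = state_shape
        self.reset_states()

    def reset_states(self) -> None:
        # Start again from a zero state, like a freshly reset stateful layer
        self.state = np.zeros(self.state_shape, dtype=np.float32)

    def __call__(self, x: np.ndarray) -> Any:
        # Run one step and keep the returned state for the next call
        output, self.state = self.step_fn(x, self.state)
        return output

def toy_step(x, h):
    # Hypothetical stand-in for a model: new state = old state + sum of inputs
    new_h = h + x.sum()
    return new_h.copy(), new_h

runner = StatefulRunner(toy_step, state_shape=(1,))
print(runner(np.array([1.0, 2.0])))  # [3.]
print(runner(np.array([4.0])))       # [7.]
runner.reset_states()
print(runner(np.array([4.0])))       # [4.]
```

Keeping the state outside the model means the ONNX graph itself stays stateless, which is all the ONNX GRU/LSTM ops support.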

dkloving, Oct 16 '20 19:10

Hello,

I have converted the Sequential model from the TensorFlow/Keras text generation tutorial into a Keras functional model, which lets me pass the RNN state (in my case from a GRU) in as an input and retrieve it as an output. It works now, but along the way I ran into two issues that caused keras2onnx to fail:

  • I had to specify unroll=True in the GRU layer; otherwise I get the following error message: ValueError: Unable to find out a correct type for tensor type = <dtype: 'variant'> of gru_1/while:3

  • I had to change the list index from 1 to 0 in build_output_states() in keras2onnx/ke2onnx/simplernn.py, at line 400 (line 394 in release 1.7.0, which I am using):

    if output_state:
        output_h = operator.outputs[0].full_name  # index changed from 1 to 0
        apply_squeeze(scope, rnn_h, output_h, container)

Otherwise I get an "index out of range" error.

I'm not an experienced machine learning or Python developer, so I may have done something wrong that causes these problems. I'm using Keras 2.4 (bundled with TensorFlow 2.3.1), keras2onnx 1.7, and Python 3.6.8.

Source code to reproduce the issues:

import tensorflow as tf
from tensorflow.keras import layers  # public API rather than tensorflow.python.keras
import keras2onnx

def createModel(input_dim, batch_size, seq_len, embedding_dim=32, rnn_state_size=48, include_RNN_state=False):
    # Token IDs with a fixed batch size (stateful RNNs require one)
    inputs = tf.keras.Input(shape=(seq_len,), dtype=tf.int32, name="Input", batch_size=batch_size)
    rnn_inputs = None

    if include_RNN_state:
        # Extra model input carrying the GRU state; the shape must be a tuple
        rnn_inputs = tf.keras.Input(shape=(rnn_state_size,), dtype=tf.float32, name="Input_RNN", batch_size=batch_size)

    embedding = layers.Embedding(input_dim=input_dim, output_dim=embedding_dim, batch_size=batch_size)
    x0 = embedding(inputs)

    rnn_outputs = None
    if include_RNN_state:
        # unroll=True was needed to avoid the 'variant' tensor error described above;
        # return_state=True exposes the final GRU state as a second output
        rnn = layers.GRU(rnn_state_size, batch_size=batch_size, stateful=True, return_state=True,
                         return_sequences=True, unroll=True, recurrent_initializer='glorot_uniform')
        x1, rnn_outputs = rnn(x0, initial_state=rnn_inputs)
    else:
        rnn = layers.GRU(rnn_state_size, batch_size=batch_size, stateful=True,
                         return_sequences=True, recurrent_initializer='glorot_uniform')
        x1 = rnn(x0)

    dense = layers.Dense(input_dim)
    outputs = dense(x1)

    if include_RNN_state:
        # The state goes in as a second input and comes back out as a second output
        return tf.keras.Model(inputs=[inputs, rnn_inputs], outputs=[outputs, rnn_outputs], name="char_predict_rnn")
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="char_predict_rnn")


model_without_rnnstate = createModel(10, 1, 1)
onnx_model = keras2onnx.convert_keras(model_without_rnnstate, model_without_rnnstate.name, debug_mode=True)

model_with_rnnstate = createModel(10, 1, 1, include_RNN_state=True)
onnx_model_stateful = keras2onnx.convert_keras(model_with_rnnstate, model_with_rnnstate.name, debug_mode=True)

helmutbressler, Oct 22 '20 14:10