tensorflow-onnx
Failed to convert LSTM cell with activation and dropout settings
Describe the bug: A Keras LSTM layer that sets both an activation function and dropout is decomposed into primitive ops instead of being converted to a single ONNX LSTM node.
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Fedora 35, x86
- Tensorflow Version: 2.9.1
- Python version: 3.9.13
- tf2onnx: e896723 (2022-07-15)
To Reproduce
from keras.layers import Input, LSTM
from keras.models import Sequential
import tf2onnx
model = Sequential()
model.add(Input(shape=(16, 16)))
model.add(LSTM(16, activation='sigmoid', dropout=0.1))
model.save('test.h5')
onnx_model, _ = tf2onnx.convert.from_keras(model, opset=16, output_path='test.onnx')  # from_keras returns (model_proto, external_tensor_storage)
Removing either activation='sigmoid' or dropout=0.1 makes the converter emit an ONNX LSTM node in the output as expected.
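A quick way to check which case you end up in is to list the op types of the exported model; a fused conversion contains a node of type LSTM. A minimal sketch using the onnx package:

```python
# Sketch: inspect the exported model for a fused LSTM node.
import onnx

model_proto = onnx.load('test.onnx')
op_types = {node.op_type for node in model_proto.graph.node}
print(op_types)
print('fused LSTM node present' if 'LSTM' in op_types else 'LSTM was decomposed')
```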
Screenshots
test.h5: (graph screenshot)
test.onnx: (graph screenshot)
This problem appears to be caused by a bogus Identity node, visible in the graph as an input to the LSTM nodes in place of the expected TensorListGetItem.
The following sequence is passed to the TensorFlow optimizer: the multiplication used as the input to the MatMul is a no-op. While the optimizer recognizes this and turns it into an Identity node, it fails to get rid of the node entirely, because grappler is intentionally invoked without the 'dependency' optimization (the optimizer list is currently just 'constfold' and 'function').
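To observe this behaviour outside of tf2onnx, here is a minimal standalone sketch, assuming TF 2.x; the tiny graph and the node names 'noop' and 'z' are illustrative, and this is not tf2onnx's actual loader code, though it invokes grappler the same way:

```python
# Sketch: a no-op multiply is folded to Identity by 'constfold', but the
# Identity is only removed once the 'dependency' optimizer runs as well.
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.grappler import tf_optimizer

with tf.Graph().as_default() as g:
    x = tf.compat.v1.placeholder(tf.float32, [2], name='x')
    noop = tf.multiply(x, tf.constant(1.0), name='noop')  # no-op multiply, like the one dropout leaves behind
    z = tf.square(noop, name='z')                         # downstream consumer
    meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=g.as_graph_def())

# Grappler preserves only what the fetch collection reaches.
meta_graph.collection_def['train_op'].node_list.value.append(z.name)

config = config_pb2.ConfigProto()
rewrite_options = config.graph_options.rewrite_options
rewrite_options.optimizers[:] = ['constfold', 'function']  # as tf2onnx currently invokes grappler
# rewrite_options.optimizers[:] = ['constfold', 'function', 'dependency']  # proposed change

optimized = tf_optimizer.OptimizeGraph(config, meta_graph)
print([(n.name, n.op) for n in optimized.node])
# Without 'dependency', 'noop' survives as an Identity node;
# with 'dependency', the Identity is bypassed and removed.
```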
Changing the TensorFlow optimizer invocation to:
# TODO: if we turn on pruning, grappler removes some identities that the tf-1.x lstm rewriter
# depends on so for now don't turn this on, constfold is always enabled now.
rewrite_options.optimizers[:] = [
    # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
    'constfold', 'function', 'dependency'
]
generates the following optimized graph (screenshot):

With that, the graph matcher no longer trips over the Identity node and generates an LSTM as expected.
The TODO claims that enabling further optimizations (pruning, in particular) removes identities that the tf-1.x LSTM rewriter depends on, which prevents those LSTMs from being recognized. We should probably fix those cases rather than leave the optimizations turned off.
tf2onnx is not currently expected to support converting models configured for training, such as models with dropout set.
But in this case the root cause appears to lie in TensorFlow's OptimizeGraph step, so we will consider adding 'dependency' to the optimizer list.
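In the meantime, one possible workaround (a sketch, not from this thread): since dropout has no effect at inference time and adds no weights, the trained weights can be copied into an identical model with dropout left at its default of 0 before exporting:

```python
# Hypothetical workaround sketch: export a dropout-free copy of the trained model.
from keras.layers import Input, LSTM
from keras.models import Sequential
import tf2onnx

export_model = Sequential()
export_model.add(Input(shape=(16, 16)))
export_model.add(LSTM(16, activation='sigmoid'))  # same layer, dropout left at its default 0.0
export_model.set_weights(model.get_weights())     # 'model' is the trained model from the repro above
onnx_model, _ = tf2onnx.convert.from_keras(export_model, opset=16, output_path='test.onnx')
```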