tensorflow-onnx
Cannot export model using an LSTM layer within a tf.while_loop or a tf.cond
**Describe the bug** The library fails to export a model that calls an LSTM layer inside a tf.while_loop. This came up when encoding an autoregressive generation loop in an ONNX model. Calling the layer's cell directly works, which can be enough for models that generate one output (or a fixed number of outputs) at a time.
Similarly, it fails when the layer is called inside a tf.cond.
**Urgency** None; I could work around the issue in my case by using the layer's cell directly.
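For reference, here is a minimal sketch of that workaround (the `generate` function and its names are mine, purely illustrative): driving an autoregressive loop one step at a time with the LSTM cell, so that no full LSTM layer call ends up inside the while_loop body.

```python
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(16)

@tf.function(input_signature=[tf.TensorSpec([None, 16], tf.float32),
                              tf.TensorSpec([], tf.int32)])
def generate(x, steps):
    # Step the cell manually instead of calling the LSTM layer in the loop.
    state = [tf.zeros_like(x), tf.zeros_like(x)]  # [h, c], units == input size here
    outputs = tf.TensorArray(tf.float32, size=steps)
    for i in tf.range(steps):  # AutoGraph turns this into a tf.while_loop
        x, state = cell(x, state)
        outputs = outputs.write(i, x)
    return tf.transpose(outputs.stack(), [1, 0, 2])  # [batch, steps, units]
```

As with the "Works" loop case below, converting such a function with tf2onnx.convert.from_function should succeed, since the cell's operations connect directly to the frozen weight constants.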
**System information**
- OS Platform and Distribution: Windows 10
- TensorFlow version: 2.12
- Python version: 3.10
- ONNX version: 1.14.0
- ONNXRuntime version: N/A
**To Reproduce** The following code demonstrates the issue:
```python
import tensorflow as tf
import tf2onnx

lstm = tf.keras.layers.LSTM(16, return_state=True)
lstm.build(tf.TensorShape([None, None, lstm.units]))

input_signature = [tf.TensorSpec([None, lstm.units])]

# Works
@tf.function
def run_lstm_once(x):
    return lstm(tf.expand_dims(x, 1))[0]

_ = tf2onnx.convert.from_function(run_lstm_once, input_signature)

# Works
@tf.function
def run_lstm_in_loop_with_cell(x):
    initial_state = lstm.get_initial_state(x)
    y, _ = tf.while_loop(cond=lambda *_: True,
                         body=lambda x, state: lstm.cell(x, state),
                         loop_vars=[x, initial_state])
    return y

_ = tf2onnx.convert.from_function(run_lstm_in_loop_with_cell, input_signature)

# Fails: rewriter <function rewriter_lstm_tf2 at ...>: exception get tensor value:
# 'while/lstm/Read/ReadVariableOp' must be Const
@tf.function
def run_lstm_in_loop(x):
    initial_state = lstm.get_initial_state(x)
    y, *_ = tf.while_loop(cond=lambda *_: True,
                          body=lambda x, *state: lstm(tf.expand_dims(x, 1), initial_state=state),
                          loop_vars=[x, *initial_state])
    return y

_ = tf2onnx.convert.from_function(run_lstm_in_loop, input_signature)

# Fails: rewriter <function rewriter_lstm_tf2 at ...>: exception get tensor value:
# 'cond/lstm/Read/ReadVariableOp_tfg_inlined_cond_0' must be Const
@tf.function
def run_lstm_in_cond(x):
    y = tf.cond(tf.constant(True), lambda: lstm(tf.expand_dims(x, 1))[0], lambda: x)
    return y

_ = tf2onnx.convert.from_function(run_lstm_in_cond, input_signature)
```
**Screenshots** N/A
**Additional context** It seems the LSTM rewriter assumes that the LSTM operations are directly connected to the constants produced by freezing the model variables. However, when the layer is called inside a loop or a conditional, the LSTM operations are created in a separate function within the model, and the layer weights, which are external to that function, appear as input parameters of the function.
I suppose the solution would be to replicate, as ONNX operations, the weight manipulations that the rewriter currently performs at conversion time, for the case where the weights are not constants (or to do so in every case and rely on graph optimization to fold those operations when they are).
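To illustrate, here is a numpy sketch of the kind of kernel manipulation involved (the function name is mine, not tf2onnx API; it assumes Keras packs the kernel as [input_size, 4*units] with gate order i, f, c, o, while ONNX LSTM expects W as [num_directions, 4*hidden_size, input_size] with gate order i, o, f, c). Each numpy step has a direct ONNX counterpart, so the same reordering could be emitted as graph ops when the kernel arrives as a function input rather than a frozen constant:

```python
import numpy as np

def keras_kernel_to_onnx_w(kernel: np.ndarray) -> np.ndarray:
    """Reorder a Keras LSTM kernel into ONNX LSTM's W layout.

    Each step maps onto an ONNX op (Split, Concat, Transpose, Unsqueeze),
    so the rewrite could be expressed in the graph instead of in numpy.
    """
    i, f, c, o = np.split(kernel, 4, axis=1)   # Split into per-gate blocks
    w = np.concatenate([i, o, f, c], axis=1)   # Concat: reorder ifco -> iofc
    return np.transpose(w)[np.newaxis, ...]    # Transpose + Unsqueeze (dir axis)

kernel = np.zeros((8, 4 * 16), dtype=np.float32)  # input_size=8, units=16
W = keras_kernel_to_onnx_w(kernel)
assert W.shape == (1, 64, 8)
```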