Export to ONNX
Hi,
Is there a way to convert pretrained RETURNN networks to ONNX, or at least to save the network in TensorFlow's SavedModel format?
Best, Musharraf
Yes. @Gerstenberger has some experience doing this. See for example https://github.com/rwth-i6/returnn/issues/1236#issuecomment-1339201812:

> yes, i use compile_tf_graph.py [from RETURNN tools/] for this and then call the tf2onnx tool on the resulting graph.
Hi,
sorry for the delayed response. I was not available during September. I remembered today that there was a question regarding ONNX exporting. In addition, I also have more information about this topic now.
TensorFlow to ONNX
Checkpoint to SavedModel
You can use a script which does something like this:
Python Script - Expand Me
import argparse
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
logger.addHandler(stream_handler)

# parse arguments before loading TF so we don't have to wait
parser = argparse.ArgumentParser(add_help=False)
required = parser.add_argument_group("required arguments")
optional = parser.add_argument_group("optional arguments")
required.add_argument("-i", "--input", help="checkpoint path", required=True)
required.add_argument("-o", "--output", help="output path", required=True)
optional.add_argument(
    "-t",
    "--output_names",
    default="output/output_batch_major:0",
    help="output tensor names separated by comma [default: %(default)s]",
)
optional.add_argument("-m", "--meta_graph", help="meta graph path")
optional.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="show this help message")
args = parser.parse_args()

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.enable_resource_variables()


def main():
    meta_graph = f"{args.input}.meta"
    if args.meta_graph:
        meta_graph = args.meta_graph

    # import the graph from the meta graph and restore the checkpoint variables
    logger.info(f"Load model [checkpoint: {args.input} meta: {meta_graph}]")
    session = tf.compat.v1.Session()
    saver = tf.compat.v1.train.import_meta_graph(meta_graph)
    session.run(tf.compat.v1.global_variables_initializer())
    saver.restore(session, args.input)

    # collect all placeholders of the graph as SavedModel inputs
    placeholders = [
        f"{x.name}:0" for x in session.graph.get_operations() if x.type == "Placeholder" or x.type == "PlaceholderV2"
    ]
    logger.info(f"Use inputs: {placeholders}")
    logger.info(f"Use outputs: {args.output_names.split(',')}")

    placeholder_map = {}
    for k in placeholders:
        placeholder_map[k] = session.graph.get_tensor_by_name(k)
    output_map = {}
    for k in args.output_names.split(","):
        output_map[k] = session.graph.get_tensor_by_name(k)

    # build the serving signature from the selected input/output tensors
    inputs = {k: tf.compat.v1.saved_model.build_tensor_info(x) for k, x in placeholder_map.items()}
    outputs = {k: tf.compat.v1.saved_model.build_tensor_info(x) for k, x in output_map.items()}
    signatures = tf.compat.v1.saved_model.build_signature_def(inputs=inputs, outputs=outputs)

    builder = tf.compat.v1.saved_model.Builder(args.output)
    builder.add_meta_graph_and_variables(
        session,
        [tf.compat.v1.saved_model.tag_constants.SERVING],
        signature_def_map={tf.compat.v1.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures},
    )
    builder.save(as_text=False)
    logger.info(f"Written SavedModel to {args.output}")


if __name__ == "__main__":
    main()
Before using this script, you should call tools/compile_tf_graph.py on your RETURNN config file and generate a .meta file.
Assuming the resulting meta graph is saved to $META_GRAPH_PATH and you named the script checkpoint2sm.py, example usage with the default output tensor name output/output_batch_major:0:
python3 checkpoint2sm.py -i $CHECKPOINT_PATH -o model-sm -m $META_GRAPH_PATH
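For completeness, the preceding compile step that produces $META_GRAPH_PATH could look roughly like this (the --output_file flag is what I recall from tools/compile_tf_graph.py; check its --help for your RETURNN version):

python3 returnn/tools/compile_tf_graph.py $CONFIG_FILE --output_file $META_GRAPH_PATH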
Beware that during tf2onnx conversion when using the SavedModel format, the input placeholders' names containing "/" and ":" are replaced by "_". So you need to account for that in ONNX inference. You can check the output of tf2onnx for the final output names. Actually, this already happens when you load the SavedModel in TensorFlow.
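To illustrate the renaming (the placeholder name below is only a typical RETURNN data input, assumed for illustration, not necessarily what your graph uses):

# illustration only: "/" and ":" in a TF tensor name become "_" in the exported inputs/outputs
tf_input_name = "extern_data/placeholders/data/data:0"  # hypothetical RETURNN data placeholder
onnx_input_name = tf_input_name.replace("/", "_").replace(":", "_")
print(onnx_input_name)  # extern_data_placeholders_data_data_0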
Once you have your SavedModel, you would now use

python3 -m tf2onnx.convert --saved-model model-sm/ --output model.onnx

Alternatively, you could use the tf2onnx API directly, which supports loading, for example, from a graph def. This is what we do internally.
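A rough sketch of that API route, assuming you have a frozen GraphDef of the compiled graph (the paths and tensor names are placeholders you would have to replace; the call follows the tf2onnx Python API as I know it):

import tensorflow as tf
from tf2onnx import convert

# load a frozen GraphDef of the compiled RETURNN graph (placeholder path)
graph_def = tf.compat.v1.GraphDef()
with open("graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

model_proto, _ = convert.from_graph_def(
    graph_def,
    input_names=["extern_data/placeholders/data/data:0"],  # hypothetical input tensor
    output_names=["output/output_batch_major:0"],
    output_path="model.onnx",
)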
Model Conversion
LSTM models
Converting RETURNN custom ops, most relevantly NativeLstm2, directly is not supported. We only have support for this internally. It requires writing custom graph rewriters and registering them in tf2onnx. Beware that NativeLstm2 doesn't do the input projection inside the op, so you would have to match that sub-graph too.
One workaround is to use a different LSTM kernel, for example vanillalstm, which does not rely on custom ops. However, for vanillalstm as a RETURNN implementation, some kernels do symbolic loops and are then not converted to the fused ONNX LSTM op, because there is no pattern matching the corresponding sub-graph in tf2onnx.
There are some LSTM ops tf2onnx supports, such as tf.keras.layers.LSTM, but there is no support for them in RETURNN.
For TensorFlow 1.15 and below, RETURNN supports CudnnLSTM, LSTMBlockFused etc., which tf2onnx also supports and rewrites into the ONNX LSTM op.
This is not needed when using vanillalstm because it is compatible with the NativeLstm2 weights layout. But then you do not get to use the fused ONNX LSTM op.
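As a minimal sketch, the layer change in a RETURNN network dict could look like this (layer names and sizes are made up; the existing NativeLstm2 checkpoint can then be loaded without any parameter conversion):

# hypothetical excerpt of a RETURNN config; only "unit" changes compared to a NativeLstm2 setup
network = {
    "lstm0": {"class": "rec", "unit": "vanillalstm", "n_out": 512, "from": "data"},
    "output": {"class": "softmax", "loss": "ce", "from": "lstm0"},
}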
For completeness, I provide an example using LSTMBlock in TensorFlow 1.15 using the network-dict frontend in RETURNN.
However, I can't recommend using such an ancient version for this; this is just for documentation purposes.
If at any given time RETURNN supports the Keras LSTMs, then you could apply the same principle, of course.
Example for TensorFlow 1.15 - Expand Me
You would change the unit of your LSTM layer and then you have to convert the parameters from `NativeLstm2` to the respective format expected by the different kernel by using `custom_param_importer` for the layer. Adapting the config to something like this for `LSTMBlock` works:

def import_lstm_params_from_nativelstm2(layer, values_dict, session):
    import numpy as np
    import tensorflow as tf

    # parameter names as stored in the NativeLstm2 checkpoint
    checkpoint_param_name_w = "W"
    checkpoint_param_name_w_rec = "W_re"
    checkpoint_param_name_b = "b"  # input projection bias
    # parameter names of the LSTMBlock cell
    param_name = "lstm_cell/kernel"
    param_name_b = "lstm_cell/bias"
    assert param_name in layer.params, f"{layer}: param {param_name} unknown"
    assert param_name_b in layer.params, f"{layer}: param {param_name_b} unknown"

    # gate order: have cifo, need icfo
    bc, bi, bf, bo = np.split(values_dict[checkpoint_param_name_b], 4, axis=0)
    new_bias = np.concatenate([bi, bc, bf, bo], axis=0)
    param_b = layer.params[param_name_b]
    assert isinstance(param_b, tf.Variable)
    layer.network.get_var_assigner(param_b).assign(new_bias, session=session)

    wc, wi, wf, wo = np.split(values_dict[checkpoint_param_name_w], 4, axis=1)
    w = np.concatenate([wi, wc, wf, wo], axis=1)
    wc_rec, wi_rec, wf_rec, wo_rec = np.split(values_dict[checkpoint_param_name_w_rec], 4, axis=1)
    w_rec = np.concatenate([wi_rec, wc_rec, wf_rec, wo_rec], axis=1)
    # LSTMBlock expects input and recurrent weights concatenated into one kernel
    new_kernel = np.concatenate([w, w_rec], axis=0)
    param_kernel = layer.params[param_name]
    assert isinstance(param_kernel, tf.Variable)
    layer.network.get_var_assigner(param_kernel).assign(new_kernel, session=session)
[...] SOME RETURNN CONFIG
model = "$OUTPUT_PATH_OF_CONVERTED_MODEL"
load = "$CHECKPOINT_PATH_TO_CONVERT"
task = "initialize_model"
In each LSTM layer, you would then add/change the options:
"unit": "lstmblock",
"custom_param_importer": import_lstm_params_from_nativelstm2,
"unit_opts": {"forget_bias": 0.0},
And then you would call
python3 returnn/rnn.py $CONFIG_FILE
assuming you saved the adapted config file to the path $CONFIG_FILE.
Then call tools/compile_tf_graph.py on the $CONFIG_FILE and convert to ONNX, same as above.
Differences
To give you an idea of the runtime difference between using a symbolic loop and the fused ONNX LSTM op, here are some timings for perplexity calculation on a small evaluation dataset, for a 2-layer LSTM language model using full softmax computation with a vocabulary size > 200k, on CPU. Quantization uses 8-bit integers and is done dynamically. All models yield the same perplexities.
| Model               | Time (sec.) | Latency (sec.) |
|---------------------|-------------|----------------|
| Vanilla             | 56.42       | 0.11           |
| Vanilla Quantized   | 51.39       | 0.10           |
| ONNX LSTM           | 55.31       | 0.11           |
| ONNX LSTM Quantized | 38.89       | 0.08           |
As you can see, the runtime difference for the non-quantized models is negligible. However, for the quantized versions the difference is huge.
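For reference, dynamic 8-bit quantization of an exported model can be done with onnxruntime's quantization tooling; a minimal sketch (file names are placeholders, and whether this matches the exact setup behind the numbers above is an assumption):

from onnxruntime.quantization import QuantType, quantize_dynamic

# dynamically quantize the weights of the exported model to 8-bit integers
quantize_dynamic(
    model_input="model.onnx",         # placeholder: exported model
    model_output="model-quant.onnx",  # placeholder: quantized model
    weight_type=QuantType.QInt8,
)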
We successfully converted BiLSTM acoustic models and LSTM language models to ONNX.
Self-Attention based models
Self-attention based models should work in RETURNN with small modifications to the RETURNN config prior to using tools/compile_tf_graph.py. Specifically:
inf_value = 1e10 # avoids NaNs in the outputs when doing masking followed by a softmax
onnx_export = True # to be specific: only needed when using CNN layers and doing sub-sampling or so because of a bug in tf2onnx but doesn't hurt
Exporting self-attention networks only works with the current tf2onnx GitHub upstream for now; the PyPI release package (1.15.1) does not have the necessary fixes for it yet. So you would do
pip install git+https://github.com/onnx/tensorflow-onnx
So far we have successfully converted self-attention based models such as Conformer and other Transformer-based acoustic models, as well as Transformer language models.
State Management
In TensorFlow, variables can carry values between session.run calls; you can assign values to them, etc. In ONNXRuntime there is no persistent state. So in order to allow state management, all initial states need to be made placeholders as additional inputs, and the final states need to be additional outputs of the ONNX model.
For that, you need to adapt your self-attention or LSTM layers and add/replace
"initial_state": "placeholder",
Now the initial states are placeholders, and the final states also have readable names.
This enables you to pass the hidden states to the onnxruntime session call and retrieve the next hidden states via the outputs.
For example, we added a mapping from placeholders to the corresponding final state tensor names in the metadata of the ONNX model, and then have the convention that the first output is the neural network output and the remaining outputs are the final states in the order they appear in the metadata.
For LSTMs, such a mapping could look like this:

mapping = {
    "lstm0/rec/initial_state_placeholder_h:0": "lstm0/rec/last_state_h:0",
    "lstm0/rec/initial_state_placeholder_c:0": "lstm0/rec/last_state_c:0",
    [...]
}
Or for self-attention, using the standard i6 transformer implementation:

mapping = {
    "output/rec/dec_0_self_att_att/initial_state_placeholder_k_left:0": "output/rec/dec_0_self_att_att/last_state_k_left:0",
    "output/rec/dec_0_self_att_att/initial_state_placeholder_v_left:0": "output/rec/dec_0_self_att_att/last_state_v_left:0",
    [...]
}
Of course you have to adapt this to your case.
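To make the convention concrete, a stepwise inference loop could look roughly like this. It is only a sketch under the assumptions that the mapping was stored as JSON under a metadata key named state_mapping, that it already uses the renamed ONNX input/output names, and that the model takes a single data input; none of that is guaranteed for your export:

import json
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx")  # placeholder path
meta = session.get_modelmeta().custom_metadata_map
# assumed convention: JSON dict {initial_state_input_name: final_state_output_name}
state_mapping = json.loads(meta["state_mapping"])

# zero-initialize all states; shapes are hypothetical and must match your model
states = {name: np.zeros((1, 512), dtype=np.float32) for name in state_mapping}

my_input_chunks = [np.zeros((1, 10), dtype=np.float32)]  # hypothetical input chunks

for chunk in my_input_chunks:
    feed = {"extern_data_placeholders_data_data_0": chunk}  # renamed input name, hypothetical; see above
    feed.update(states)
    outputs = session.run(None, feed)
    net_output = outputs[0]  # first output is the network output
    # remaining outputs are the final states, in the order they appear in the metadata
    for i, state_input_name in enumerate(state_mapping):
        states[state_input_name] = outputs[i + 1]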
Model Optimizations
It is not possible to use Transformer-based model optimizations such as the onnxruntime transformers tooling. This is because the patterns used there do not match the RETURNN implementation.
What we did internally is to write custom graph rewrites, similar to the NativeLstm2 case, and apply the optimizations manually this way. This results in a decent speedup, for example by using the onnxruntime custom MultiHeadAttention op instead of the RETURNN SelfAttention sub-graph.
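As a rough skeleton of how such a rewrite can be hooked into tf2onnx (this is not our internal code, just an illustration of the hook as I understand it; the actual pattern matching and replacement is the real work and is omitted here):

def rewrite_selfatt_to_fused_op(g, ops):
    """Illustrative skeleton of a custom tf2onnx graph rewriter.

    A rewriter receives the tf2onnx Graph `g` and its ops, may match and
    replace a sub-graph (e.g. a self-attention pattern with a fused op),
    and returns the resulting ops list.
    """
    # ... match and rewrite the sub-graph here, e.g. with tf2onnx.graph_matcher ...
    return ops

# It is then passed to the conversion via the custom_rewriter argument,
# e.g. tf2onnx.tfonnx.process_tf_graph(..., custom_rewriter=[rewrite_selfatt_to_fused_op], ...).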
Every other model should work out of the box. If you encounter issues or so, feel free to ask specifics.
Hope this answers your question. Feel free to let me know if you have questions.
Thanks for the detailed write-up. Maybe you can add this to our documentation or RETURNN wiki?
> Python Script - Expand Me
> ...

> What we did internally is to write custom graph rewrites, similar to the NativeLstm2 case, and apply the optimizations manually this way. This results in a decent speedup, for example by using the onnxruntime custom MultiHeadAttention op instead of the RETURNN SelfAttention sub-graph.

Maybe you can put those scripts also to tools or so?
> Thanks for the detailed write-up. Maybe you can add this to our documentation or RETURNN wiki?

Yes, I can do that.

> Maybe you can put those scripts also to tools or so?

I can add the checkpoint-to-savedmodel script in the next days.
As for the rewriters, that will take a little bit longer. Parts of it are not even integrated internally yet.
It would be easier to actually integrate it into the i6_core repository at some point, together with the tensorflow-to-onnx job we use, so we don't have to manage the same code twice and it is independent of the RETURNN version we use.
> It would be easier to actually integrate it into the i6_core repository at some point, together with the tensorflow-to-onnx job we use, so we don't have to manage the same code twice and it is independent of the RETURNN version we use.
It always depends. As soon as there is anything non-trivial, any slightly complex logic, even if independent of RETURNN, we would put it in tools, and then the job in i6_core would just call this script. So in principle a user could also run the tool directly and not use Sisyphus / i6_core.
If it's so trivial, basically anyway just one single command to run, which is also just another external script, that of course would not justify an own script in tools.
But I think such a rewriting thing sounds very much like it should be its own script? And I think RETURNN tools is the optimal place for this.