
Add masked LSTM support

Open q-ycong-p opened this issue 2 years ago • 3 comments

Masking is an intra-layer behavior in TF LSTM [1] but is not an intra-op behavior in ONNX LSTM [2]. When converted to ONNX, a masked TF LSTM layer is converted to a Loop op. This over-complicates the ONNX model and hurts inference performance in ORT, because ORT's LSTM optimizations are not used. (issue #1871)

This commit adds support for converting masked LSTM correctly, under the important assumption that the input is post-padded, which is the most common use case. The "masking" info is conveyed to the ONNX LSTM op as sequence_lens, which is computed dynamically by summing the number of non-skipped timesteps for each batch entry, per LSTM. This behavior is implemented with reference to keras2onnx PR#386 [3]. Additional logic is added for backward LSTM so that the input sequence is reversed correctly given sequence_lens.
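To illustrate how sequence_lens follows from the mask, here is a minimal sketch in plain Python. It is not the converter code from this PR: the helper names `mask_from_token_ids` and `sequence_lens_from_mask` and the padding id `0` are assumptions made for the example. It only shows the idea of counting non-skipped timesteps per batch entry.

```python
def mask_from_token_ids(batch, pad_id=0):
    """Return a boolean mask: True where the timestep holds a real token.

    pad_id=0 is an assumption matching an Embedding layer with mask_zero=True.
    """
    return [[tok != pad_id for tok in row] for row in batch]


def sequence_lens_from_mask(mask):
    """Count the non-skipped timesteps of each batch entry."""
    return [sum(1 for keep in row if keep) for row in mask]


# Three post-padded sequences with 5 timesteps each.
batch = [
    [5, 3, 9, 0, 0],
    [7, 0, 0, 0, 0],
    [2, 4, 6, 8, 1],
]
mask = mask_from_token_ids(batch)
seq_lens = sequence_lens_from_mask(mask)
print(seq_lens)  # [3, 1, 5]
```

In the converted graph the same reduction would run on tensors at inference time, which is why sequence_lens is described as dynamically computed.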

Note that if masking is enabled and the LSTM input is pre-padded or randomly padded, the converted ONNX model will behave incorrectly at inference. Unless ONNX adds a new attribute (e.g. mask_enabled) to RNN ops, the converter alone may not be able to handle generic masking while keeping the RNN ops, since masking alters intra-op behavior. Given this limitation, I'd like to share this PR for further comments and suggestions.
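The limitation can be seen in a small plain-Python sketch. The helpers `is_post_padded` and `steps_seen_with_sequence_lens` are hypothetical names for this illustration, and padding id `0` is assumed. A sequence length only tells the op how many leading timesteps to process, so it selects the right timesteps only when all padding comes at the end.

```python
def is_post_padded(mask_row):
    """True if no real timestep appears after a padded one."""
    seen_pad = False
    for keep in mask_row:
        if not keep:
            seen_pad = True
        elif seen_pad:
            return False
    return True


def steps_seen_with_sequence_lens(row, pad_id=0):
    """Timesteps processed when only a length is known: the first n."""
    n = sum(1 for tok in row if tok != pad_id)
    return row[:n]


post_padded = [5, 3, 9, 0, 0]
pre_padded = [0, 0, 5, 3, 9]

post_ok = is_post_padded([tok != 0 for tok in post_padded])
pre_ok = is_post_padded([tok != 0 for tok in pre_padded])

# Both rows have length 3, but only the post-padded row yields the real tokens.
seen_post = steps_seen_with_sequence_lens(post_padded)  # [5, 3, 9]
seen_pre = steps_seen_with_sequence_lens(pre_padded)    # [0, 0, 5]
```

A check along the lines of `is_post_padded` could be applied to input data before relying on the converted model.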

[1] https://www.tensorflow.org/guide/keras/masking_and_padding#masking
[2] https://github.com/onnx/onnx/blob/main/docs/Operators.md#LSTM
[3] https://github.com/onnx/keras-onnx/pull/386


Details:

Forward LSTM

Here's a minimal example of an LSTM fed by an Embedding layer (mask_zero=True):

  • H5 model: Screen Shot 2022-08-26 at 6 18 13 PM

  • tf2onnx-converted ONNX model, before proposed change: Screen Shot 2022-08-26 at 6 17 49 PM

  • tf2onnx-converted ONNX model, after proposed change: Screen Shot 2022-08-26 at 6 18 57 PM

Reverse LSTM

  • The tf.raw_ops.ReverseV2 behavior needs to be changed to ReverseSequence so that the LSTM input is reversed correctly: reverse_masked_lstm
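The difference between the two reversal behaviors can be sketched in plain Python. This is an illustration of the semantics only, not the converter's rewrite logic, and the function names `reverse_v2` and `reverse_sequence` are stand-ins for the corresponding ops acting on the time axis of a single batch entry.

```python
def reverse_v2(row):
    """Reverse the whole time axis, padding included."""
    return row[::-1]


def reverse_sequence(row, seq_len):
    """Reverse only the first seq_len timesteps and leave the padding in place."""
    return row[:seq_len][::-1] + row[seq_len:]


row = [5, 3, 9, 0, 0]  # post-padded, 3 real timesteps
seq_len = 3

full_reverse = reverse_v2(row)                   # [0, 0, 9, 3, 5]
masked_reverse = reverse_sequence(row, seq_len)  # [9, 3, 5, 0, 0]
```

After a full reversal the padding ends up at the front, so the first seq_len timesteps would no longer be the real ones. Reversing only the valid prefix keeps the input post-padded, which is what the sequence_lens approach requires.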

q-ycong-p, Aug 27 '22 01:08

Sorry, I will address the test failures on TF-2.9 soon.

q-ycong-p, Oct 17 '22 18:10

Hello, is there any progress on this issue?

AndreyOrb, Aug 18 '23 17:08

Hi, is there any update? Will the proposed code work if pulled?

AndreyOrb, Jan 02 '24 20:01