
Failure in onnx conversion from tflite

Open gouravchat opened this issue 2 years ago • 5 comments

ONNX does not support the TensorFlow ops TensorListReserve, TensorListGetItem, TensorListSetItem, TensorListStack, TensorListFromTensor, TensorListConcatV2, TensorListGather...

Error output:

```
>>> import onnx
>>> m = onnx.load('test.onnx')
>>> onnx.checker.check_model(m)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\scmckay\AppData\Roaming\Python\Python310\site-packages\onnx\checker.py", line 119, in check_model
    C.check_model(protobuf_string, full_check)
onnx.onnx_cpp2py_export.checker.ValidationError: No Op registered for TensorListReserve with domain_version of 17
```

gouravchat avatar Jan 27 '23 04:01 gouravchat

The original model seems to have a bi-directional LSTM, but the conversion to ONNX is a) not detecting that, so no ONNX LSTM node is used, and b) writing out an ONNX model that contains invalid nodes (operators such as TensorListReserve and TensorListFromTensor) without failing the conversion.

skottmckay avatar Jan 27 '23 04:01 skottmckay

Attaching the file: tflite_onnx.zip

gouravchat avatar Jan 27 '23 05:01 gouravchat

This particular model can be exported with 2 changes that @TomWildenhain-Microsoft kindly figured out.

  1. Use the TF processing for TFL_While, similar to TFL_IF, by removing TflWhile.version_7 and adding a to_tf to re-map the attributes:

```python
@tfl_op(["TFL_WHILE"], tf_op="While")
class TflWhile:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["cond"] = node.attr["cond_subgraph_index"]
        del node.attr["cond_subgraph_index"]
        node.attr["body"] = node.attr["body_subgraph_index"]
        del node.attr["body_subgraph_index"]
```
  2. Use replace_input instead of replace_all_inputs in the TF While handling so that the shared TensorListReserve node isn't removed prematurely:
```diff
diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py
index b6f70ced..6c989530 100644
--- a/tf2onnx/onnx_opset/controlflow.py
+++ b/tf2onnx/onnx_opset/controlflow.py
@@ -494,9 +494,10 @@ class While:
                 ragged_scan_output_names.append(body_ragged_name)
                 ragged_scan_output_to_len[output_names[idx]] = external_ragged_name
                 continue
-            ctx.remove_node(n.name)
+
             # make the node output bad
-            ctx.replace_all_inputs(n.output[0], "@@ALLOC")  # ops=ctx.get_nodes()
+            ctx.replace_input(node, node.input[idx], "@@ALLOC", idx)
+
             del body.inputs[idx]
             del cond_graph.inputs[idx]
             del tf_while_inputs[idx]
```

It's not clear, though, how or when the first change (using the TF While handling) can be applied, or how to conditionally invoke the TF handler from a TFL handler, as currently it's one or the other.

Additionally, the conversion isn't able to detect there was originally a bi-directional LSTM that could (in theory) be converted to a single ONNX LSTM node, so you end up with a Loop node for each direction which is going to be less efficient.

skottmckay avatar Jan 30 '23 00:01 skottmckay

Hi @fatcat-z, is this fixed by #2123?

natke avatar Feb 28 '23 22:02 natke

> Hi @fatcat-z, is this fixed by #2123?

No, it is not. The changes above may only help that particular model; they are not general enough, which caused some CI failures.

fatcat-z avatar Mar 01 '23 00:03 fatcat-z