tensorflow-onnx
Failure in onnx conversion from tflite
Does ONNX support the TensorFlow ops TensorListReserve, TensorListGetItem, TensorListSetItem, TensorListStack, TensorListFromTensor, TensorListConcatV2, TensorListGather, ...?
Error output:
```
>>> import onnx
>>> m = onnx.load('test.onnx')
>>> onnx.checker.check_model(m)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\scmckay\AppData\Roaming\Python\Python310\site-packages\onnx\checker.py", line 119, in check_model
    C.check_model(protobuf_string, full_check)
onnx.onnx_cpp2py_export.checker.ValidationError: No Op registered for TensorListReserve with domain_version of 17
```
The original model seemed to have a bi-directional LSTM but the conversion to onnx is a) not detecting that and using an ONNX LSTM node, and b) writing out an ONNX model that has invalid nodes with operators such as TensorListReserve and TensorListFromTensor without failing the conversion.
Attaching the file: tflite_onnx.zip
This particular model can be exported with 2 changes that @TomWildenhain-Microsoft kindly figured out.
- Use the TF processing for TFL_While (similar to TFL_IF) by removing TflWhile.version_7 and adding a to_tf method to re-map the attributes:
```python
@tfl_op(["TFL_WHILE"], tf_op="While")
class TflWhile:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["cond"] = node.attr["cond_subgraph_index"]
        del node.attr["cond_subgraph_index"]
        node.attr["body"] = node.attr["body_subgraph_index"]
        del node.attr["body_subgraph_index"]
```
- Use `replace_input` instead of `replace_all_inputs` in the TF While handling so that the shared TensorListReserve node isn't removed prematurely.
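The difference matters when one TensorListReserve output feeds several consumers. A toy sketch of the two behaviors (a stand-in `Node` class, not the real tf2onnx API) shows why rewriting every consumer's edge at once clobbers inputs that other loop variables still need, while rewriting a single input slot does not:

```python
# Toy contrast of replace_all_inputs vs replace_input; Node is a
# stand-in for tf2onnx's graph nodes, not the real API.

class Node:
    def __init__(self, name, inputs):
        self.name, self.inputs = name, list(inputs)

def replace_all_inputs(nodes, old, new):
    """Rewrite 'old' wherever it appears, in every node (too aggressive)."""
    for n in nodes:
        n.inputs = [new if i == old else i for i in n.inputs]

def replace_input(node, old, new, idx):
    """Rewrite only one input slot of one node (the targeted fix)."""
    if node.inputs[idx] == old:
        node.inputs[idx] = new

# Two While loop variables share the output of one TensorListReserve.
while_a = Node("while_a", ["reserve:0", "x:0"])
while_b = Node("while_b", ["reserve:0", "y:0"])

replace_input(while_a, "reserve:0", "@@ALLOC", 0)
print(while_a.inputs)  # ['@@ALLOC', 'x:0']
print(while_b.inputs)  # ['reserve:0', 'y:0'] -- still intact
```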
```diff
diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py
index b6f70ced..6c989530 100644
--- a/tf2onnx/onnx_opset/controlflow.py
+++ b/tf2onnx/onnx_opset/controlflow.py
@@ -494,9 +494,10 @@ class While:
             ragged_scan_output_names.append(body_ragged_name)
             ragged_scan_output_to_len[output_names[idx]] = external_ragged_name
             continue
-        ctx.remove_node(n.name)
+
         # make the node output bad
-        ctx.replace_all_inputs(n.output[0], "@@ALLOC")  # ops=ctx.get_nodes()
+        ctx.replace_input(node, node.input[idx], "@@ALLOC", idx)
+
         del body.inputs[idx]
         del cond_graph.inputs[idx]
         del tf_while_inputs[idx]
```
It's not clear how or when the first change (using the TF While handling) can be applied, though, or how to conditionally use the TF handler from a TFL handler, since currently it's one or the other.
Additionally, the conversion isn't able to detect that there was originally a bi-directional LSTM that could (in theory) be converted to a single ONNX LSTM node, so you end up with a Loop node for each direction, which is going to be less efficient.
Hi @fatcat-z, is this fixed by #2123?
No, it is not. The above changes may only help on that particular model; they are not general enough, and they caused some CI failures.