tensorflow-onnx
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
Describe the bug
The error in the title is produced when converting a tflite file to ONNX via tf2onnx.convert. The tflite file was produced by converting a JAX function to tflite via tf.lite.TFLiteConverter.experimental_from_jax, as shown in the Colab notebook linked below.
Urgency
No hard deadline.
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 (Google Colab)
- Tensorflow Version: 2.8.0
- Python version: 3.7.13
To Reproduce
I've uploaded the model here, but it's 900 MB, so it's probably much faster to give you the notebook that generates the model.
Click Runtime > Run all in this Colab notebook to reproduce:
https://colab.research.google.com/drive/1DygMV-Nlae6BEJmZjN_laIdfYL33BbE0
Summary of what the Colab notebook does to generate the model (in case it's helpful at all):
- Load Flax CLIP
- Create a score function that uses the CLIP model inside it
- Get the gradient of that score function with jax.grad(score)
- Convert that jax.grad(score) function to tflite using tf.lite.TFLiteConverter.experimental_from_jax
- Check that the tflite output matches the jax output
- Convert that tflite file to ONNX with python -m tf2onnx.convert (which produces the error)
Note that another potential route from JAX to ONNX is via jax2tf, but I found that the tensorflow function that's produced by jax2tf can't be converted to ONNX due to tf2onnx's lack of support for PartitionedCall. I thought that perhaps the tflite to ONNX route via experimental_from_jax and tf2onnx might work, but ran into this error.
The error log starts off like this:
/usr/lib/python3.7/runpy.py:125: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
warn(RuntimeWarning(msg))
2022-05-16 18:18:30,858 - INFO - Using tensorflow=2.8.0, onnx=1.11.0, tf2onnx=1.10.1/a37f29
2022-05-16 18:18:30,858 - INFO - Using opset <onnx, 13>
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
2022-05-16 18:18:42,690 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147\nInputs:\n\txla_computation(score)/jit(main)/add;199=Add, [1, 50, 768], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];=Const, [768, 768], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147=[1, 50, 768], 1'
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
"Only keep_num_dims=False supported for fully connected op")
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
It then repeats that same ValueError many times and finishes with this:
2022-05-16 18:18:44,763 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139\nInputs:\n\txla_computation(score)/jit(main)/add_any;114=Add, [1, 50, 3072], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];67=Const, [768, 3072], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139=[1, 50, 768], 1'
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
"Only keep_num_dims=False supported for fully connected op")
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
Traceback (most recent call last):
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 640, in <module>
main()
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 287, in main
output_path=args.output)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 162, in _convert_common
custom_op_handlers=custom_op_handlers, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 439, in process_tf_graph
initialized_tables, tensors_to_rename, is_tflite, dequantize)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 492, in process_graphs
dequantize)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 512, in process_parsed_graph
raise exceptions[0]
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
"Only keep_num_dims=False supported for fully connected op")
File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
Have you solved the problem now?
same problem
To anyone who finds this later: based on the TFLite docs, keep_num_dims is probably turned on when the input to the Fully Connected layer has more than 2 dimensions. That matches the log above, where the failing nodes take inputs of shape [1, 50, 768] and [1, 50, 3072]. Not sure if that will help, but that's that.
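To make the keep_num_dims behaviour concrete, here is a minimal NumPy sketch. It is not tf2onnx or TFLite code: the fully_connected helper is hypothetical, and it assumes the usual TFLite convention that weights are stored as [units, in_features], with bias and activation left out. The shapes are taken from the last failing node in the log. It suggests that a rank-3 fully connected op with keep_num_dims=True is the same as flattening to 2-D, doing the keep_num_dims=False matmul, and reshaping back.

```python
import numpy as np


def fully_connected(x, weights, keep_num_dims=False):
    """Hypothetical reference for a TFLite-style fully connected op.

    Assumes weights have shape [units, in_features]; no bias or activation.
    """
    units, in_features = weights.shape
    # keep_num_dims=False semantics: collapse all leading dims into one batch dim.
    flat = x.reshape(-1, in_features)
    out = flat @ weights.T
    if keep_num_dims:
        # Restore the original leading dims, replacing the last dim with `units`.
        out = out.reshape(*x.shape[:-1], units)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 50, 3072))  # input shape from the failing node
w = rng.standard_normal((768, 3072))    # constant weights from the failing node

kept = fully_connected(x, w, keep_num_dims=True)
flat = fully_connected(x, w, keep_num_dims=False)

# Same contraction as dot_general with dimension_numbers (((2,), (1,)), ((), ())):
# contract axis 2 of x with axis 1 of w, no batch dimensions.
reference = np.tensordot(x, w, axes=([2], [1]))

print(kept.shape)  # (1, 50, 768)
print(flat.shape)  # (50, 768)
print(np.allclose(kept, reference))  # True
```

If that reading is right, a converter could in principle handle keep_num_dims=True by wrapping the existing 2-D path in a pair of reshapes, though whether tf2onnx does or should do this is not something this thread confirms.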