
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op

Open josephrocca opened this issue 2 years ago • 3 comments

Describe the bug The error in the title is produced when converting a tflite file to ONNX via tf2onnx.convert. The tflite file was produced by converting a JAX function to tflite via tf.lite.TFLiteConverter.experimental_from_jax, as shown in the Colab notebook linked below.

Urgency No hard deadline.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 (Google Colab)
  • Tensorflow Version: 2.8.0
  • Python version: 3.7.13

To Reproduce I've uploaded the model here, but it's 900 MB, so it's probably much faster to use the notebook that generates the model.

Click Runtime > Run all in this Colab notebook to reproduce:

https://colab.research.google.com/drive/1DygMV-Nlae6BEJmZjN_laIdfYL33BbE0

Summary of what the Colab notebook does to generate the model (in case it's helpful at all):

  1. Load Flax CLIP
  2. Create a score function that uses the CLIP model inside it
  3. Get the gradient of that score function with jax.grad(score)
  4. Convert that jax.grad(score) function to tflite using tf.lite.TFLiteConverter.experimental_from_jax
  5. Check that the tflite output matches the jax output
  6. Convert that tflite file to ONNX with python -m tf2onnx.convert (which produces the error)
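Step 6 can be sketched as a small Python helper that builds the command line. This is only an illustration: the file names `model.tflite` and `model.onnx` are placeholders rather than the names used in the notebook, and opset 13 is taken from the log below. The `--tflite`, `--output` and `--opset` flags are the ones documented for the tf2onnx command-line converter.

```python
import shlex
import sys


def build_tf2onnx_command(tflite_path, onnx_path, opset=13):
    """Return the argv list that runs tf2onnx's CLI on a TFLite file."""
    return [
        sys.executable, "-m", "tf2onnx.convert",
        "--tflite", tflite_path,   # input TFLite flatbuffer
        "--output", onnx_path,     # where the ONNX model is written
        "--opset", str(opset),     # ONNX opset to target
    ]


# Placeholder paths, not the ones from the notebook.
cmd = build_tf2onnx_command("model.tflite", "model.onnx")
print(shlex.join(cmd))

# To actually run the conversion (requires tf2onnx to be installed):
# import subprocess
# subprocess.run(cmd, check=True)
```

With a model like the one described here, running that command is what produces the error shown below.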

Note that another potential route from JAX to ONNX is via jax2tf, but I found that the TensorFlow function produced by jax2tf can't be converted to ONNX because tf2onnx doesn't support PartitionedCall. I thought the tflite-to-ONNX route via experimental_from_jax and tf2onnx might work instead, but ran into this error.

The error log starts off like this:

/usr/lib/python3.7/runpy.py:125: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
2022-05-16 18:18:30,858 - INFO - Using tensorflow=2.8.0, onnx=1.11.0, tf2onnx=1.10.1/a37f29
2022-05-16 18:18:30,858 - INFO - Using opset <onnx, 13>
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
2022-05-16 18:18:42,690 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147\nInputs:\n\txla_computation(score)/jit(main)/add;199=Add, [1, 50, 768], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];=Const, [768, 768], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None];147=[1, 50, 768], 1'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op

It then repeats that same ValueError many times and finishes with this:

2022-05-16 18:18:44,763 - ERROR - Failed to convert node 'xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139' (fct=<bound method TflFullyConnectedOp.to_tf of <class 'tf2onnx.tflite_handlers.tfl_math.TflFullyConnectedOp'>>)
'OP=TFL_FULLY_CONNECTED\nName=xla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139\nInputs:\n\txla_computation(score)/jit(main)/add_any;114=Add, [1, 50, 3072], 1\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];67=Const, [768, 3072], 1\nOutpus:\n\txla_computation(score)/jit(main)/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=None];139=[1, 50, 768], 1'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op
Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 640, in <module>
    main()
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 287, in main
    output_path=args.output)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py", line 162, in _convert_common
    custom_op_handlers=custom_op_handlers, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 439, in process_tf_graph
    initialized_tables, tensors_to_rename, is_tflite, dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 492, in process_graphs
    dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 512, in process_parsed_graph
    raise exceptions[0]
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/tflite_handlers/tfl_math.py", line 205, in to_tf
    "Only keep_num_dims=False supported for fully connected op")
  File "/usr/local/lib/python3.7/dist-packages/tf2onnx/utils.py", line 264, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Only keep_num_dims=False supported for fully connected op

josephrocca, May 16 '22 18:05

Have you solved the problem now?

pyl62112991, Oct 10 '23 02:10

same problem

buptlj, Dec 08 '23 07:12

To anyone who finds this later: based on the TFLite docs, keep_num_dims is probably turned on when the input to the Fully Connected layer has more than 2 dimensions. Not sure if that will help, but that's that.

Doomski99, Jan 23 '24 16:01
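To illustrate the last comment about keep_num_dims, here is a minimal NumPy sketch of what the flag changes. It models the semantics as I understand them and is not tf2onnx's or TFLite's actual code. It assumes the weights are laid out as [units, depth], which matches the shapes in the log above (input [1, 50, 3072], weights [768, 3072], output [1, 50, 768]), and it leaves out the bias and activation. The function name `fully_connected` and the small shapes are made up for the example.

```python
import numpy as np


def fully_connected(x, weights, keep_num_dims=False):
    """NumPy model of a fully connected op without bias (an assumption-based sketch).

    x:       array of shape [..., depth]
    weights: array of shape [units, depth]
    """
    units, depth = weights.shape
    # Collapse all leading dimensions into one batch dimension.
    flat = x.reshape(-1, depth)
    out = flat @ weights.T  # shape [batch, units]
    if keep_num_dims:
        # Restore the original leading dimensions.
        out = out.reshape(x.shape[:-1] + (units,))
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 50, 8)).astype(np.float32)  # rank-3 input, as in the log
w = rng.standard_normal((4, 8)).astype(np.float32)      # [units, depth]

kept = fully_connected(x, w, keep_num_dims=True)
flat = fully_connected(x, w, keep_num_dims=False)

print(kept.shape)  # (1, 50, 4)
print(flat.shape)  # (50, 4)
```

Under these assumptions the two results hold the same values and differ only in shape, so a rank-3 input with keep_num_dims=True could in principle be handled as the 2-D case followed by a reshape. Whether that is a practical fix inside tf2onnx is not something this thread confirms.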