
tf2onnx failed: google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto'

wushandinghua opened this issue on Aug 27, 2025

Describe the bug

I have the inference function and params of a JAX model and converted them to a TF SavedModel. I hit an issue when converting the SavedModel to an ONNX model. How can I solve it? tf2onnx output:

<frozen runpy>:128: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
2025-08-27 18:19:57,445 - WARNING - tf2onnx.tf_loader: '--tag' not specified for saved_model. Using --tag serve
2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Signatures found in model: [serving_default].
2025-08-27 18:20:01,172 - WARNING - tf2onnx.tf_loader: '--signature_def' not specified, using first signature: serving_default
2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Output names: ['output_0']
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1756290001.225592    5292 devices.cc:76] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
I0000 00:00:1756290001.225854    5292 single_machine.cc:376] Starting new session
2025-08-27 18:20:11,084 - INFO - tf2onnx: inputs: ['inputs_0:0', 'inputs_1:0', 'inputs_2:0', 'inputs_3:0', 'inputs_4:0', 'inputs_5:0', 'inputs_6:0']
2025-08-27 18:20:11,084 - INFO - tf2onnx: outputs: ['Identity:0']
2025-08-27 18:20:14,362 - INFO - tf2onnx.tfonnx: Using tensorflow=2.20.0, onnx=1.17.0, tf2onnx=1.16.1/15c810
2025-08-27 18:20:14,363 - INFO - tf2onnx.tfonnx: Using opset <onnx, 21>
2025-08-27 18:20:20,789 - ERROR - tf2onnx.tf_utils: pass1 convert failed for name: "unknown_43"
op: "Const"
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_HALF
      tensor_shape {
        dim {
          size: 18
        }
        dim {
          size: 2
        }
        dim {
          size: 2048
        }
        dim {
          size: 16384
        }
      }
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_HALF
  }
}
, ex=Error parsing message with type 'onnx.AttributeProto'
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 714, in <module>
    main()
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 273, in main
    model_proto, _ = _convert_common(
                     ^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 168, in _convert_common
    g = process_tf_graph(tf_graph, const_node_values=const_node_values,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 459, in process_tf_graph
    main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 474, in graphs_from_tf
    ordered_func = resolve_functions(tf_graph)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_loader.py", line 784, in resolve_functions
    _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_utils.py", line 463, in tflist_to_onnx
    onnx_node = utils.make_onnx_node_with_attr(node_type, input_names, output_names, name=node.name, **attr)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/utils.py", line 207, in make_onnx_node_with_attr
    onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **valid_attrs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in make_node
    node.attribute.extend(
  File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in <genexpr>
    node.attribute.extend(
                         ^
google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto'
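
For context: that single constant is larger than one protobuf message can hold, which I suspect (my reading, not confirmed) is why AttributeProto decoding fails. A quick size check:

# Size of the Const tensor tf2onnx choked on (shape 18 x 2 x 2048 x 16384, DT_HALF).
# DT_HALF is 2 bytes per element; a protobuf message is capped at 2**31 - 1 bytes.
n_bytes = 18 * 2 * 2048 * 16384 * 2
print(n_bytes)              # 2415919104, about 2.25 GiB
print(n_bytes > 2**31 - 1)  # True: over the 2 GiB protobuf limit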

tf2onnx command:

python -m tf2onnx.convert --saved-model /dev/shm/tmp/tf_model --output /dev/shm/pi0_galaxea_lora.onnx --opset 21 --large_model --verbose
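
Note that I already pass --large_model, but the failure happens in pass1 while the TF graph is still being parsed, i.e. before anything is serialized to disk. A small diagnostic sketch along these lines can confirm which constants are oversized (my own sketch, not part of tf2onnx; the path follows the command above):

import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load("/dev/shm/tmp/tf_model")
graph_def = loaded.signatures["serving_default"].graph.as_graph_def()

def const_nbytes(node):
    # Bytes held by a Const node's embedded tensor proto.
    t = node.attr["value"].tensor
    n_elems = int(np.prod([d.size for d in t.tensor_shape.dim], dtype=np.int64))
    return n_elems * np.dtype(tf.dtypes.as_dtype(t.dtype).as_numpy_dtype).itemsize

# Constants created by jax2tf may sit in the function library rather than
# the top-level graph, so scan both.
nodes = list(graph_def.node)
for fn in graph_def.library.function:
    nodes.extend(fn.node_def)

for node in nodes:
    if node.op == "Const" and const_nbytes(node) > 2**31 - 1:
        print(node.name, const_nbytes(node))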

Script to convert the JAX model to a TF SavedModel:

import jax
import numpy as np
import tensorflow as tf
from flax import nnx
from jax.experimental import jax2tf


def jax2tf_saved_model(inference_fn, params, save_path, batch_size, action_dim, max_token_len):
    """Convert a JAX inference function and its params to a TF SavedModel."""

    def extract_value(p):
        # Recursively unwrap nnx state containers into plain nested dicts of values.
        if isinstance(p, (dict, nnx.State)):
            return {k: extract_value(v) for k, v in p.items()}
        elif isinstance(p, nnx.variablelib.VariableState):
            return p.value
        return p

    params_plain = extract_value(params)
    
    # print("params_plain:", params_plain)
    print("get value finished")

    def to_tf_variable(x):
        # Not used in the final export (see the commented-out call below),
        # but can be useful for debugging.
        if isinstance(x, (float, int, bool, list, tuple)):
            return tf.Variable(x)
        elif isinstance(x, dict):
            return {k: to_tf_variable(v) for k, v in x.items()}
        elif isinstance(x, jax.Array):
            return tf.Variable(tf.convert_to_tensor(np.asarray(x, copy=False)))
        return x

    # params_vars = to_tf_variable(params_plain)
    params_vars = tf.nest.map_structure(tf.Variable, params_plain)
    del params_plain
    print(params_vars)
    print("to tf variable finished")
    
    input_specs = [
        tf.TensorSpec([2], tf.uint32),  # rng
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # base image
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # left image
        tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # right image
        tf.TensorSpec([batch_size, action_dim], tf.float32),  # state
        tf.TensorSpec([batch_size, max_token_len], tf.int32),  # tokens
        tf.TensorSpec([batch_size, max_token_len], tf.bool),  # token mask
    ]
    my_model = tf.Module()
    my_model._variables = tf.nest.flatten(params_vars)
    prediction_tf = lambda *inputs: jax2tf.convert(
        inference_fn, native_serialization=False, with_gradient=False
    )(params_vars, *inputs)
    my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False, input_signature=input_specs)
    tf.saved_model.save(my_model, f'{save_path}/tf_model', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
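
For completeness, the function is invoked roughly like this (the exact values here are illustrative, not the ones from my run):

# Hypothetical invocation; batch_size / action_dim / max_token_len are examples.
jax2tf_saved_model(inference_fn, params, "/dev/shm/tmp",
                   batch_size=1, action_dim=32, max_token_len=48)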

System information

  • OS Platform and Distribution: NVIDIA JetPack 6.1
  • TensorFlow version: 2.20
  • Python version: 3.11
  • ONNX version: 1.17.0
  • ONNXRuntime version: none

To Reproduce

Export the TF SavedModel with the jax2tf script above, then run the tf2onnx command above on it.