onnxconverter-common

Cannot convert model that uses external data to float16

Open adamreeve opened this issue 1 year ago • 6 comments

It appears that onnxconverter_common.float16.convert_float_to_float16 doesn't convert external data from float to float16.

If I convert a model that uses external data for parameter values, the conversion appears to work, but then inference fails when getting the external data.

Repro code:

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx.onnx_pb import TensorProto
from onnxconverter_common import float16
import onnxruntime


f32_model_path = "f32.onnx"
f16_model_path = "f16.onnx"

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, ["N", 1])

rng = np.random.default_rng(0)
num_rows = 1000
x_array = rng.normal(0.0, 10.0, num_rows).astype(np.float32).reshape((num_rows, 1))
x_tensor = numpy_helper.from_array(x_array, "X")

# Constant node with Tensor valued parameter
const_node = helper.make_node(
        "Constant",
        inputs=[],
        outputs=["X"],
        value=x_tensor)

graph_def = helper.make_graph(
        nodes=[const_node],
        name="test-model",
        inputs=[],
        outputs=[X])

opset_import = helper.make_opsetid("", 21)
model_def = helper.make_model(
        graph_def,
        opset_imports=[opset_import],
        producer_name="onnx-example")

onnx.save_model(
        model_def, f32_model_path,
        save_as_external_data=True,
        location='data',
        size_threshold=0,
        convert_attribute=True)

model = onnx.load(f32_model_path, load_external_data=False)
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, f16_model_path)

modelf16 = onnx.load(f16_model_path)

# Run inference
session = onnxruntime.InferenceSession(f16_model_path)
output = session.run(["X"], {})[0]
print(output)

Output:

2024-11-26 13:21:05.007257146 [E:onnxruntime:, inference_session.cc:2117 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/optimizer_execution_frame.cc:71 onnxruntime::OptimizerExecutionFrame::Info::Info(const std::vector<const onnxruntime::Node*>&, const onnxruntime::InitializedTensorSet&, const std::filesystem::__cxx11::path&, const onnxruntime::IExecutionProvider&, const std::function<bool(const std::__cxx11::basic_string<char>&)>&) [ONNXRuntimeError] : 1 : FAIL : tensorprotoutils.cc:189 GetExternalDataInfo TensorProto: graph_output_cast_0 external data size mismatch. Computed size: 2000, external_data.length: 4000

Traceback (most recent call last):
  File "/home/adam/dev/gross/onnx-issues/float16-external-data/minimal_repro.py", line 63, in <module>
    session = onnxruntime.InferenceSession(f16_model_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 465, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 537, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/optimizer_execution_frame.cc:71 onnxruntime::OptimizerExecutionFrame::Info::Info(const std::vector<const onnxruntime::Node*>&, const onnxruntime::InitializedTensorSet&, const std::filesystem::__cxx11::path&, const onnxruntime::IExecutionProvider&, const std::function<bool(const std::__cxx11::basic_string<char>&)>&) [ONNXRuntimeError] : 1 : FAIL : tensorprotoutils.cc:189 GetExternalDataInfo TensorProto: graph_output_cast_0 external data size mismatch. Computed size: 2000, external_data.length: 4000
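
The two sizes in the error are consistent with the tensor's declared type having been changed to float16 while the external file still holds the original float32 bytes. A quick arithmetic check, using only NumPy and the element count from the repro (a sketch of the reasoning, not of what the converter does internally):

```python
import numpy as np

num_rows = 1000  # same element count as x_array in the repro

# Bytes a runtime would expect for a float16 tensor with 1000 elements.
expected_f16_bytes = num_rows * np.dtype(np.float16).itemsize

# Bytes actually written to the external file when the data was float32.
stored_f32_bytes = num_rows * np.dtype(np.float32).itemsize

print(expected_f16_bytes, stored_f32_bytes)  # 2000 4000
```

These match the "Computed size: 2000" and "external_data.length: 4000" values reported above.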

When an initializer uses external data, I instead get an error calling convert_float_to_float16.

Repro code:

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx.onnx_pb import TensorProto
from onnxconverter_common import float16
import onnxruntime


f32_model_path = "f32.onnx"
f16_model_path = "f16.onnx"

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, ["N", 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["N", 1])

rng = np.random.default_rng(0)
num_rows = 1000
x_array = rng.normal(0.0, 10.0, num_rows).astype(np.float32).reshape((num_rows, 1))
x_tensor = numpy_helper.from_array(x_array, "X")

abs_node = helper.make_node(
        "Abs",
        inputs=["X"],
        outputs=["Y"])

graph_def = helper.make_graph(
        nodes=[abs_node],
        name="test-model",
        inputs=[],
        outputs=[Y],
        initializer=[x_tensor])

opset_import = helper.make_opsetid("", 21)
model_def = helper.make_model(
        graph_def,
        opset_imports=[opset_import],
        producer_name="onnx-example")

onnx.save_model(
        model_def, f32_model_path,
        save_as_external_data=True,
        location='data',
        size_threshold=0,
        convert_attribute=True)

model = onnx.load(f32_model_path, load_external_data=False)
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, f16_model_path)

modelf16 = onnx.load(f16_model_path)

# Run inference
session = onnxruntime.InferenceSession(f16_model_path)
output = session.run(["Y"], {})[0]  # "Y" is the graph output; "X" is the initializer
print(output)

Output:

Traceback (most recent call last):
  File "/home/adam/dev/gross/onnx-issues/float16-external-data/initializer_minimal_repro.py", line 46, in <module>
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnxconverter_common/float16.py", line 250, in convert_float_to_float16
    value_info_list.append(make_value_info_from_tensor(n))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnxconverter_common/float16.py", line 100, in make_value_info_from_tensor
    shape = numpy_helper.to_array(tensor).shape
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnx/numpy_helper.py", line 407, in to_array
    return _to_array(tensor, base_dir=base_dir)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnx/numpy_helper.py", line 292, in _to_array
    return np.frombuffer(raw_data, dtype=np_dtype).reshape(dims)  # type: ignore[no-any-return]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: cannot reshape array of size 2000 into shape (1000,1)
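
This ValueError is consistent with the 4000 bytes of float32 external data being read back with a float16 dtype, which yields 2000 elements instead of 1000. The snippet below reproduces the same message with NumPy alone; the zero-filled buffer is a stand-in for the external data file, not the actual model data:

```python
import numpy as np

shape = (1000, 1)
# 1000 float32 values occupy 4000 bytes, like the external data file in the repro.
raw = np.zeros(shape, dtype=np.float32).tobytes()

# Interpreting those same bytes as float16 yields twice as many elements.
as_f16 = np.frombuffer(raw, dtype=np.float16)
print(len(raw), as_f16.size)

try:
    as_f16.reshape(shape)
except ValueError as exc:
    reshape_error = str(exc)
    print(reshape_error)
```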

If I patch numpy.frombuffer to work around this, the conversion succeeds, but inference then fails with an error similar to the one in the first example.

Repro code:

from contextlib import contextmanager
import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx.onnx_pb import TensorProto
from onnxconverter_common import float16
import onnxruntime
from unittest.mock import patch


@contextmanager
def patch_frombuffer():
    prev_frombuffer = np.frombuffer
    def new_frombuffer(*args, **kwargs):
        kwargs["dtype"] = np.float32
        return prev_frombuffer(*args, **kwargs).astype(np.float16)

    with patch("numpy.frombuffer", new_frombuffer):
        yield


f32_model_path = "f32.onnx"
f16_model_path = "f16.onnx"

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, ["N", 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["N", 1])

rng = np.random.default_rng(0)
num_rows = 1000
x_array = rng.normal(0.0, 10.0, num_rows).astype(np.float32).reshape((num_rows, 1))
x_tensor = numpy_helper.from_array(x_array, "X")

abs_node = helper.make_node(
        "Abs",
        inputs=["X"],
        outputs=["Y"])

graph_def = helper.make_graph(
        nodes=[abs_node],
        name="test-model",
        inputs=[],
        outputs=[Y],
        initializer=[x_tensor])

opset_import = helper.make_opsetid("", 21)
model_def = helper.make_model(
        graph_def,
        opset_imports=[opset_import],
        producer_name="onnx-example")

onnx.save_model(
        model_def, f32_model_path,
        save_as_external_data=True,
        location='data',
        size_threshold=0,
        convert_attribute=True)

model = onnx.load(f32_model_path, load_external_data=False)
with patch_frombuffer():
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, f16_model_path)

modelf16 = onnx.load(f16_model_path)

# Run inference
session = onnxruntime.InferenceSession(f16_model_path)
output = session.run(["Y"], {})[0]  # "Y" is the graph output; "X" is the initializer
print(output)

Output:

2024-11-26 14:02:52.745117577 [E:onnxruntime:, inference_session.cc:2117 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/optimizer_execution_frame.cc:71 onnxruntime::OptimizerExecutionFrame::Info::Info(const std::vector<const onnxruntime::Node*>&, const onnxruntime::InitializedTensorSet&, const std::filesystem::__cxx11::path&, const onnxruntime::IExecutionProvider&, const std::function<bool(const std::__cxx11::basic_string<char>&)>&) [ONNXRuntimeError] : 1 : FAIL : tensorprotoutils.cc:189 GetExternalDataInfo TensorProto: X external data size mismatch. Computed size: 2000, external_data.length: 4000

Traceback (most recent call last):
  File "/home/adam/dev/gross/onnx-issues/float16-external-data/initializer_with_patch.py", line 82, in <module>
    session = onnxruntime.InferenceSession(f16_model_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 465, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/adam/dev/virtualenvs/ml/lib64/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 537, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/optimizer/optimizer_execution_frame.cc:71 onnxruntime::OptimizerExecutionFrame::Info::Info(const std::vector<const onnxruntime::Node*>&, const onnxruntime::InitializedTensorSet&, const std::filesystem::__cxx11::path&, const onnxruntime::IExecutionProvider&, const std::function<bool(const std::__cxx11::basic_string<char>&)>&) [ONNXRuntimeError] : 1 : FAIL : tensorprotoutils.cc:189 GetExternalDataInfo TensorProto: X external data size mismatch. Computed size: 2000, external_data.length: 4000

Versions:

>>> onnx.__version__
'1.17.0'
>>> onnxconverter_common.__version__
'1.14.0'
>>> onnxruntime.__version__
'1.20.1'

adamreeve avatar Nov 26 '24 01:11 adamreeve

Not a proper solution, but one workaround you could try is to load all external data into the model proto during onnx.load. The loaded model object should have all its external data already initialized in TensorProto.raw_data, so the converter might just treat it as a regular model?

LeoZDong avatar Jan 30 '25 23:01 LeoZDong

Yes that should work, but only if the model stays under the 2 GB size limit imposed by Protobuf.

adamreeve avatar Jan 31 '25 13:01 adamreeve

@LeoZDong I have the same issue, any solutions here?

Kotomi-Du avatar Mar 14 '25 17:03 Kotomi-Du

> @LeoZDong I have the same issue, any solutions here?

Not really other than manually hacking the external data to be FP16. I think the converter should really support this though as it's a common enough use case...

LeoZDong avatar Mar 14 '25 22:03 LeoZDong

> > @LeoZDong I have the same issue, any solutions here?
>
> Not really other than manually hacking the external data to be FP16. I think the converter should really support this though as it's a common enough use case...

Yeah, I think so. Any comment here? @yetingqiaqia

Kotomi-Du avatar Mar 14 '25 23:03 Kotomi-Du

Hello, I'm facing the same issue: conversion of a big model is failing. Any solution for this would be appreciated.

rYm-A avatar Aug 05 '25 14:08 rYm-A