tensorflow-onnx
Recent ml_dtypes release (0.3.0) removed ml_dtypes.float8_e4m3b11, breaking tf2onnx
Describe the bug
Title says it all, but here's a stack trace:
AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'. Did you mean: 'float8_e4m3fn'?
Traceback (most recent call last):
File "<frozen runpy>", line 189, in _run_module_as_main
File "<frozen runpy>", line 112, in _get_module_details
File "<redacted path>/lib/python3.11/site-packages/tf2onnx/__init__.py", line 10, in <module>
from . import verbose_logging as logging
File "<redacted path>/lib/python3.11/site-packages/tf2onnx/verbose_logging.py", line 14, in <module>
import tensorflow as tf
File "<redacted path>/lib/python3.11/site-packages/tensorflow/__init__.py", line 38, in <module>
from tensorflow.python.tools import module_util as _module_util
File "<redacted path>/lib/python3.11/site-packages/tensorflow/python/__init__.py", line 42, in <module>
from tensorflow.python.saved_model import saved_model
File "<redacted path>/lib/python3.11/site-packages/tensorflow/python/saved_model/saved_model.py", line 20, in <module>
from tensorflow.python.saved_model import builder
File "<redacted path>/lib/python3.11/site-packages/tensorflow/python/saved_model/builder.py", line 23, in <module>
from tensorflow.python.saved_model.builder_impl import _SavedModelBuilder
File "<redacted path>/lib/python3.11/site-packages/tensorflow/python/saved_model/builder_impl.py", line 26, in <module>
from tensorflow.python.framework import dtypes
File "<redacted path>/lib/python3.11/site-packages/tensorflow/python/framework/dtypes.py", line 39, in <module>
_np_float8_e5m2 = pywrap_ml_dtypes.float8_e5m2()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Unable to convert function return value to a Python type! The signature was
() -> handle
This error does not appear with ml_dtypes version 0.2.0.
Urgency
Seems urgent to me: folks are going to have to implement temporary workarounds, such as pinning ml_dtypes.
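A minimal sketch, assuming you want to detect the incompatible ml_dtypes before importing TensorFlow or JAX. The helper `module_has_attr` is a hypothetical name, not part of any library; it only uses the standard `importlib` machinery to check whether an installed module exposes a given attribute, without raising.

```python
import importlib
import importlib.util
from typing import Optional


def module_has_attr(module_name: str, attr: str) -> Optional[bool]:
    """Hypothetical helper: report whether `module_name` exposes `attr`.

    Returns None if the (top-level) module is not installed,
    otherwise True or False depending on whether the attribute exists.
    """
    # find_spec returns None when a top-level module cannot be found.
    if importlib.util.find_spec(module_name) is None:
        return None
    module = importlib.import_module(module_name)
    return hasattr(module, attr)


if __name__ == "__main__":
    # The attribute TensorFlow/JAX tried to access, per the AttributeError above.
    status = module_has_attr("ml_dtypes", "float8_e4m3b11")
    if status is None:
        print("ml_dtypes is not installed")
    elif status:
        print("ml_dtypes still provides float8_e4m3b11")
    else:
        print("ml_dtypes lacks float8_e4m3b11; consider pinning ml_dtypes==0.2.0")
```

If the check reports the attribute is missing, pinning ml_dtypes to 0.2.0 (the version reported as working below) is the kind of temporary workaround mentioned above.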
System information
- OS Platform and Distribution: Linux Ubuntu 20.04
- TensorFlow Version: tensorflow-cpu-aws (2.14.0rc1)
- Python version: 3.11.1
- ONNX version: 1.14.1
- ONNXRuntime version (if applicable, e.g. 1.11*): N/A
To Reproduce
python -m tf2onnx.convert --tflite $model_path --output $onnx_name --opset 13
(I don't think it's worth the time to give you a minimal working example, as the issue is quite obvious given the stack trace and the newly released version of ml_dtypes.)
On second look, this might actually be something that needs to be addressed in TensorFlow (or, in my case, tensorflow-cpu-aws).
Also broken on jax-metal-0.0.4 per the example here: https://developer.apple.com/metal/jax/
python -c 'import jax; print(jax.numpy.arange(10))'
AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'. Did you mean: 'float8_e4m3fn'?
This tool doesn't rely on the ml_dtypes module. According to the stack trace, the error is raised while importing TensorFlow, so it doesn't look like a tf2onnx issue.