Properly Support BFloat16
Describe the bug
Currently, the TF BFloat16 data type maps to ONNX Float16:
https://github.com/onnx/tensorflow-onnx/blob/main/tf2onnx/tf_utils.py#L31
This is visible in ONNX graphs generated by tf2onnx: for example, a cast to bfloat16 and a cast to float16 both result in elem_type: 10, which is FLOAT16, as seen here:
https://github.com/onnx/onnx/blob/main/onnx/onnx.proto#L515
However, ONNX does define a separate BFLOAT16 type (elem_type: 16), as seen here:
https://github.com/onnx/onnx/blob/main/onnx/onnx.proto#L526
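To make the collision concrete, here is a minimal standard-library Python sketch. The integer codes are the TensorProto.DataType values from onnx.proto (FLOAT16 = 10, BFLOAT16 = 16). CURRENT_MAPPING and PROPOSED_MAPPING are hypothetical, trimmed-down stand-ins for tf2onnx's TF_TO_ONNX_DTYPE, keyed by dtype name rather than by types_pb2 enum so that neither TensorFlow nor ONNX needs to be installed.

```python
from enum import IntEnum


class OnnxElemType(IntEnum):
    """Subset of TensorProto.DataType values from onnx.proto."""
    FLOAT = 1
    FLOAT16 = 10
    DOUBLE = 11
    BFLOAT16 = 16


# Hypothetical stand-ins for tf2onnx's TF_TO_ONNX_DTYPE, keyed by TF dtype name.
CURRENT_MAPPING = {
    "float32": OnnxElemType.FLOAT,
    "float16": OnnxElemType.FLOAT16,
    "bfloat16": OnnxElemType.FLOAT16,  # the behaviour reported in this issue
    "float64": OnnxElemType.DOUBLE,
}
# Same mapping with only the bfloat16 entry corrected.
PROPOSED_MAPPING = {**CURRENT_MAPPING, "bfloat16": OnnxElemType.BFLOAT16}


def colliding_dtypes(mapping):
    """Return {elem_type: [tf dtype names]} for codes shared by several TF dtypes."""
    by_code = {}
    for tf_name, code in mapping.items():
        by_code.setdefault(code, []).append(tf_name)
    return {code: names for code, names in by_code.items() if len(names) > 1}


print(colliding_dtypes(CURRENT_MAPPING))   # float16 and bfloat16 share elem_type 10
print(colliding_dtypes(PROPOSED_MAPPING))  # {}
```

Once two TF dtypes share one ONNX code, nothing downstream can tell them apart, which is why a bfloat16 cast and a float16 cast both appear as elem_type: 10.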
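I consider this a bug rather than a feature request because the two types are not interchangeable: bfloat16 has 8 exponent bits and 7 mantissa bits, whereas float16 has 5 exponent bits and 10 mantissa bits, so they differ in both range and precision.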
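The range and precision gap can be checked with Python's standard library alone. This is a sketch: to_float16 and to_bfloat16 are hypothetical helpers, and bfloat16 rounding is assumed to be round-to-nearest-even on the upper 16 bits of the float32 encoding (a converter could also simply truncate).

```python
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision (5 exponent bits, 10 mantissa bits)."""
    # struct's "e" format packs binary16 and raises OverflowError if x is out of range.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Round x to bfloat16 (8 exponent bits, 7 mantissa bits).

    Assumes round-to-nearest-even on the upper 16 bits of the float32
    encoding; NaN inputs are not handled specially in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit, then drop the low 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


x = 1.0 + 2 ** -10
print(to_float16(x))     # 1.0009765625: exact in float16
print(to_bfloat16(x))    # 1.0: the 2**-10 step is below bfloat16's precision
print(to_bfloat16(1e5))  # 99840.0: within bfloat16's float32-like range
try:
    to_float16(1e5)
except OverflowError:
    print("1e5 overflows float16, whose largest finite value is 65504")
```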
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 18.04*):
- TensorFlow Version:
- Python version:
- ONNX version (if applicable, e.g. 1.11*):
- ONNXRuntime version (if applicable, e.g. 1.11*):
Additional context
The following workaround seems to work for my limited purposes:
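import tf2onnx
from onnx import onnx_pb
from tensorflow.core.framework import types_pb2

# Map TF DT_BFLOAT16 to ONNX BFLOAT16 (16) instead of FLOAT16 (10) before converting.
tf2onnx.tf_utils.TF_TO_ONNX_DTYPE[types_pb2.DT_BFLOAT16] = onnx_pb.TensorProto.BFLOAT16
# fn (the tf.function to convert), input_signature and extra_opset are user-supplied.
fn_onnx, _ = tf2onnx.convert.from_function(
    fn,
    input_signature=input_signature,
    extra_opset=extra_opset,
)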
It is unclear, though, whether this has unforeseen consequences for graph conversions more complicated than my own.
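If the override should not persist for later conversions in the same process, one option (an assumption, not something tf2onnx documents for this purpose) is to scope it with unittest.mock.patch.dict, which restores the dictionary when the block exits. The sketch below uses a hypothetical stand-in dictionary and a hypothetical convert function so it runs without tf2onnx.

```python
from unittest import mock

# Hypothetical stand-in for tf2onnx.tf_utils.TF_TO_ONNX_DTYPE, keyed by dtype
# name instead of types_pb2 enum values so the sketch runs without TensorFlow.
TF_TO_ONNX_DTYPE = {"float16": 10, "bfloat16": 10}
ONNX_BFLOAT16 = 16  # TensorProto.BFLOAT16 in onnx.proto


def convert(dtype_name: str) -> int:
    """Hypothetical converter that reads the shared mapping at call time."""
    return TF_TO_ONNX_DTYPE[dtype_name]


# patch.dict applies the override only inside the with-block and restores
# the original entries on exit, even if the conversion raises.
with mock.patch.dict(TF_TO_ONNX_DTYPE, {"bfloat16": ONNX_BFLOAT16}):
    inside = convert("bfloat16")

outside = convert("bfloat16")
print(inside, outside)  # 16 10
```

With tf2onnx installed, the same pattern would wrap the tf2onnx.convert.from_function call and patch tf2onnx.tf_utils.TF_TO_ONNX_DTYPE with {types_pb2.DT_BFLOAT16: onnx_pb.TensorProto.BFLOAT16}. This only limits how long the override is active; it says nothing about whether every op handler in tf2onnx copes with BFLOAT16 tensors.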