
Properly Support BFloat16

Open AndrewJBean opened this issue 1 year ago • 1 comments

Describe the bug

Currently, the TF BFloat16 data type maps to ONNX Float16:

https://github.com/onnx/tensorflow-onnx/blob/main/tf2onnx/tf_utils.py#L31

This is visible in ONNX graphs generated by tf2onnx: a cast to bfloat16 and a cast to float16 both produce elem_type: 10, which is FLOAT16, as defined here:

https://github.com/onnx/onnx/blob/main/onnx/onnx.proto#L515

However, ONNX does have a distinct BFLOAT16 type, defined here:

https://github.com/onnx/onnx/blob/main/onnx/onnx.proto#L526
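For reference, the two enum values involved (taken from the TensorProto.DataType enum in onnx.proto linked above) are:

```python
# TensorProto.DataType values from onnx.proto:
FLOAT16 = 10   # IEEE-754 half precision: 1 sign, 5 exponent, 10 mantissa bits
BFLOAT16 = 16  # truncated float32 ("brain float"): 1 sign, 8 exponent, 7 mantissa bits

# tf2onnx currently emits 10 for both tf.float16 and tf.bfloat16 casts,
# so the bfloat16 information is lost in the exported graph.
```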

I consider this a bug rather than a feature request, because these are not interchangeable types: they allocate their 16 bits differently, trading mantissa precision against exponent range.
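A minimal stdlib-only sketch of why the two formats are not interchangeable: bfloat16 is just float32 with the low 16 bits dropped, so it keeps float32's dynamic range, while float16 tops out near 65504.

```python
import struct

def to_bfloat16_bits(x: float) -> int:
    # bfloat16 = the top 16 bits of the IEEE-754 float32 representation
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def from_bfloat16_bits(b: int) -> float:
    (x,) = struct.unpack("<f", struct.pack("<I", b << 16))
    return x

# 1e10 survives a bfloat16 round trip (at ~7 bits of mantissa precision):
print(from_bfloat16_bits(to_bfloat16_bits(1e10)))

# float16 (5 exponent bits, max finite value 65504) cannot hold it at all;
# struct's half-precision "e" format raises on overflow:
try:
    struct.pack("<e", 1e10)
except OverflowError:
    print("1e10 overflows float16")
```

So silently mapping bfloat16 to FLOAT16 can change both range and precision of the converted graph.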

Urgency

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 18.04*):
  • TensorFlow Version:
  • Python version:
  • ONNX version (if applicable, e.g. 1.11*):
  • ONNXRuntime version (if applicable, e.g. 1.11*):

To Reproduce

Screenshots

Additional context

AndrewJBean avatar Jun 06 '24 17:06 AndrewJBean

The following workaround seems to work for my limited purposes:

  import tf2onnx
  from onnx import onnx_pb
  from tensorflow.core.framework import types_pb2

  # Patch the dtype map so DT_BFLOAT16 maps to ONNX BFLOAT16 instead of FLOAT16:
  tf2onnx.tf_utils.TF_TO_ONNX_DTYPE[types_pb2.DT_BFLOAT16] = onnx_pb.TensorProto.BFLOAT16
  fn_onnx, _ = tf2onnx.convert.from_function(
      fn,
      input_signature=input_signature,
      extra_opset=extra_opset,
  )

It is unclear, though, whether this has unforeseen consequences for graph conversions more complicated than my own.

AndrewJBean avatar Jun 07 '24 16:06 AndrewJBean