tensorflow-onnx

False warning about unsupported doubles in runtime GEMM operations

Open: kmooney47 opened this issue 1 year ago • 2 comments

Describe the bug

When exporting a TF/Keras model trained with float64 (F64), tf2onnx warns that onnxruntime lacks float64 support for GEMM:

```python
import onnx
import tf2onnx
onnx_model, _ = tf2onnx.convert.from_keras(model, opset=14)  # model: the F64-trained Keras model
onnx.save(onnx_model, "./testF64.onnx")
```

```
WARNING:root:For now, onnxruntime only support float32 type for Gemm rewriter
```

The warning is emitted from: https://github.com/onnx/tensorflow-onnx/blob/main/tf2onnx/rewriter/gemm_rewriter.py#L74
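Because the message is only logged, not raised, it can be filtered out without patching tf2onnx. The `WARNING:root:` prefix matches the default `levelname:name:message` format of Python's `logging` module, which suggests the record is emitted on the root logger. The minimal sketch below uses the standard-library `logging.Filter`; it assumes the message text matches the log line quoted above exactly and that it really goes through the root logger. The class and function names are hypothetical.

```python
import logging

# Assumption: this is the exact text of the log line quoted above.
GEMM_DTYPE_WARNING = "For now, onnxruntime only support float32 type for Gemm rewriter"


class DropGemmDtypeWarning(logging.Filter):
    """Drop only the Gemm-rewriter float32 notice; let every other record through."""

    def filter(self, record: logging.LogRecord) -> bool:
        return GEMM_DTYPE_WARNING not in record.getMessage()


def install_filter() -> DropGemmDtypeWarning:
    """Attach the filter to the root logger and return it so it can be removed later."""
    f = DropGemmDtypeWarning()
    # Logger-level filters only see records logged directly on that logger,
    # which is why this relies on the message being logged on the root logger.
    logging.getLogger().addFilter(f)
    return f
```

Call `install_filter()` before `tf2onnx.convert.from_keras(...)`. Matching on the full message text keeps the filter narrow, so unrelated warnings are still shown.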

As far as I can tell, ONNX Runtime (ORT) does support double-precision GEMM; that support was added about a year after this warning was introduced in tf2onnx. I suspect the warning was simply never updated once double support landed. I've tested the exported F64 ONNX models, and they appear to use doubles throughout the computation.
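Whether a computation really ran in doubles can be sanity-checked numerically: a small term that float64 keeps is rounded away in float32. As a reference point, here is a minimal pure-Python sketch of the ONNX Gemm formula Y = alpha * A' * B' + beta * C, where A' and B' are A and B, optionally transposed (the `transA`/`transB` attributes). Python floats are IEEE doubles; the `emulate_f32` flag is a hypothetical parameter of this sketch (not an ONNX or ORT option) that rounds every intermediate to float32 via `struct`. Broadcasting of C is not handled here.

```python
import struct
from typing import Callable, List, Optional

Matrix = List[List[float]]


def to_f32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def gemm(a: Matrix, b: Matrix, c: Optional[Matrix] = None,
         alpha: float = 1.0, beta: float = 1.0,
         trans_a: bool = False, trans_b: bool = False,
         emulate_f32: bool = False) -> Matrix:
    """Y = alpha * A' * B' + beta * C, in double precision or with float32 rounding."""
    # Identity keeps full double precision; to_f32 rounds after every operation.
    r: Callable[[float], float] = to_f32 if emulate_f32 else (lambda x: x)
    a = transpose(a) if trans_a else a
    b = transpose(b) if trans_b else b
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError("inner dimensions of A' and B' do not match")
    n = len(b[0])
    y: Matrix = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc = r(acc + r(r(a[i][p]) * r(b[p][j])))
            val = r(r(alpha) * acc)
            if c is not None:
                # This sketch requires C to have shape (M, N); ONNX also allows broadcasting.
                val = r(val + r(r(beta) * r(c[i][j])))
            row.append(val)
        y.append(row)
    return y
```

For example, with A = [[1.0, 1.0]] and B = [[1.0], [1e-8]], the double path returns 1 + 1e-8, while the float32 path returns exactly 1.0, because 1e-8 is smaller than half a float32 ulp at 1.0. Comparing an exported model's output on a similar input against both paths is one way to see which precision was actually used.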

See this commit in the onnxruntime repository that adds double support: https://github.com/microsoft/onnxruntime/commit/5968a91ea66d7a1f0568b042b8393f0b4c9c8a55

Based on this, I believe we can remove that warning, or at least update it to reflect what the runtime currently supports.
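One way to update the warning would be to check the Gemm element type against the types the runtime handles and warn only for the rest. The sketch below is hypothetical and is not tf2onnx's code: the function name, the string dtype names, and the supported set (float32, plus float64 per the linked onnxruntime commit) are assumptions for illustration.

```python
import logging
from typing import Optional

# Hypothetical supported set: float32, plus float64 per the onnxruntime commit cited in this issue.
ORT_GEMM_SUPPORTED_DTYPES = frozenset({"float32", "float64"})

logger = logging.getLogger("gemm_dtype_check")  # hypothetical logger name


def check_gemm_dtype(dtype: str) -> Optional[str]:
    """Log and return a warning only if dtype is outside the assumed supported set."""
    if dtype in ORT_GEMM_SUPPORTED_DTYPES:
        return None
    msg = (f"onnxruntime Gemm is assumed to support only "
           f"{sorted(ORT_GEMM_SUPPORTED_DTYPES)}; got {dtype}")
    logger.warning(msg)
    return msg
```

Keeping the supported set in one named constant means a future runtime change only needs a one-line update, rather than leaving a hard-coded float32 message behind.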

Urgency

Low

To Reproduce

Export a Keras model trained in float64 (F64) using tf2onnx.

kmooney47 · Aug 21 '24 15:08

Did you try removing that line from gemm_rewriter.py and exporting your model? Did the export succeed?

fatcat-z · Aug 22 '24 02:08

The model exports without issue either way: that line only emits a warning, not an error.

kmooney47 · Aug 26 '24 15:08