tensorflow-onnx
Share constant between multiple Mul operators
Hello,
I'm not sure whether this is a bug or a feature in tf2onnx. When I convert a TensorFlow model that contains multiple residual blocks, the constant used by the multiply operations is shared between multiple Mul operators in the ONNX graph. The following figures show this. As a result, inference on the ONNX model produces a weird output compared with the TensorFlow output.
A simple script to reproduce the issue:
import json
import tensorflow as tf
import keras
import tf2onnx
import onnx

INPUT_SHAPE = (1, 128, 128, 3)


def residual_block(input_layer):
    num_feature = input_layer.shape[-1]
    conv = keras.layers.Conv2D(filters=num_feature, kernel_size=3, padding='same', bias_initializer='ones')(input_layer)
    return input_layer + 0.2*conv


def sample_network(input_layer):
    x = keras.layers.Conv2D(filters=16, kernel_size=3, padding="same", bias_initializer='ones')(input_layer)
    for _ in range(3):
        x = residual_block(x)
        x = keras.layers.LeakyReLU()(x)
    return x


def build_sample_network(input_shape):
    input_net = tf.keras.Input(shape=input_shape, dtype=tf.float32, name="input")
    output = sample_network(input_net)
    model = tf.keras.Model(inputs=input_net, outputs=output)
    model_json = json.loads(model.to_json())
    with open("sample_network.json", "w") as f:
        json.dump(model_json, f, indent=2)
    return model


def build_sample_onnx():
    # build TF model
    model = build_sample_network(input_shape=INPUT_SHAPE[1:])
    # convert tf to onnx model
    spec = (tf.TensorSpec((None, *INPUT_SHAPE[1:]), tf.float32, name="input"),)
    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=12)
    onnx.save_model(onnx_model, "sample_network.onnx")


if __name__ == "__main__":
    build_sample_onnx()
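To see which Mul nodes actually share a constant in an exported graph, a small helper like the one below may be useful. It is only a sketch and not part of tf2onnx: `find_shared_mul_constants` is a hypothetical name, and the demo nodes are stand-ins with made-up names, so the block runs without the onnx package. It only assumes that each node exposes `op_type`, `name`, and `input`, which ONNX graph nodes do.

```python
from collections import defaultdict
from types import SimpleNamespace


def find_shared_mul_constants(nodes, constant_names):
    """Return {constant name: [Mul node names]} for constants that feed
    more than one Mul node."""
    consumers = defaultdict(list)
    for node in nodes:
        if node.op_type != "Mul":
            continue
        for input_name in node.input:
            if input_name in constant_names:
                consumers[input_name].append(node.name)
    # Keep only constants consumed by at least two Mul nodes.
    return {name: users for name, users in consumers.items() if len(users) > 1}


# Stand-in nodes with hypothetical names, mimicking two residual blocks
# whose 0.2 scale factor was merged into a single constant.
demo_nodes = [
    SimpleNamespace(op_type="Mul", name="mul_0", input=["conv_0", "const_0_2"]),
    SimpleNamespace(op_type="Add", name="add_0", input=["x_0", "mul_0_out"]),
    SimpleNamespace(op_type="Mul", name="mul_1", input=["conv_1", "const_0_2"]),
]
shared = find_shared_mul_constants(demo_nodes, {"const_0_2"})
print(shared)
```

With a real model, `nodes` could be `onnx.load("sample_network.onnx").graph.node` and `constant_names` could be the set of `init.name` for each entry in `graph.initializer`. Constants may also be emitted as `Constant` nodes, in which case their output names would need to be added to the set. Note that a constant shared by several Mul nodes is read-only, so sharing by itself should not change the numerical result.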
Could you please confirm whether this is an issue in tf2onnx? Thank you, Viet
Could you please share the weird output you mentioned? I tried your code with tf 2.9.1 and tf2onnx 1.12, and the results from tensorflow and onnxruntime are the same, so I could not reproduce your issue.
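One way to make such a comparison concrete is to report the largest element-wise difference between the two outputs. The following is a minimal sketch, not the check used above: `compare_outputs` is a hypothetical helper, the tolerances are assumed values, and the arrays are synthetic stand-ins so the block runs with NumPy alone.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-4, atol=1e-5):
    """Summarize how closely two output tensors agree.

    rtol and atol are assumed defaults, not values taken from tf2onnx.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    abs_diff = np.abs(reference - candidate)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
    }


# Synthetic stand-ins for the two outputs; a real check would use the
# TensorFlow prediction and the onnxruntime result for the same input.
rng = np.random.default_rng(0)
tf_like = rng.standard_normal((1, 128, 128, 16)).astype(np.float32)
ort_like = tf_like + np.float32(1e-7)  # tiny perturbation, like float rounding
report = compare_outputs(tf_like, ort_like)
print(report)
```

For the script in this issue, `reference` could be `model.predict(x)` and `candidate` could be the first element returned by `onnxruntime.InferenceSession("sample_network.onnx").run(None, {"input": x})`, where `x` is a float32 array of shape `(1, 128, 128, 3)`. Depending on the onnxruntime version, the session may also need an explicit `providers` argument. Posting the resulting numbers would make the reported difference easier to reproduce.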