Conformer
TFLite conversion fails.
**Describe the bug**
```
File "/home/modelparser/Conformer/conformer_tf/conformer_tf.py", line 168, in call
    inputs = self.conv(inputs) + inputs
File "/home/modelparser/Conformer/conformer_tf/conformer_tf.py", line 128, in call
    return self.net(inputs)
File "/home/modelparser/Conformer/conformer_tf/conformer_tf.py", line 89, in call
    return tf.keras.layers.BatchNormalization(axis=-1)(inputs)
ValueError: Exception encountered when calling layer 'batch_norm' (type BatchNorm).

tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Call arguments received by layer 'batch_norm' (type BatchNorm):
  • inputs=tf.Tensor(shape=(1, 1024, 1024), dtype=float32)
```
When I tried to convert the .h5 model to .tflite using the Conformer block, I got the error above. It is caused by the `BatchNorm` class, which creates a new `tf.keras.layers.BatchNormalization` (and therefore new `tf.Variable`s) on every call. I fixed the `BatchNorm` class like this, and the conversion succeeded:
```python
class BatchNorm(tf.keras.layers.Layer):
    def __init__(self, causal, **kwargs):
        super(BatchNorm, self).__init__(**kwargs)
        self.causal = causal
        # Create the BatchNormalization layer once here, not inside call().
        self.bnorm = tf.keras.layers.BatchNormalization(axis=-1)

    def call(self, inputs):
        if not self.causal:
            return self.bnorm(inputs)
        return tf.identity(inputs)
```
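The rule behind the fix is framework-agnostic: stateful objects must be created once at construction time and reused across calls, which is exactly what `tf.function` enforces for `tf.Variable`s. A minimal plain-Python sketch of the difference (the list here stands in for a `tf.Variable`):

```python
class BadLayer:
    """Creates its state inside the call -- the pattern tf.function rejects."""
    def __call__(self, x):
        self.state = []      # fresh state object on every call
        self.state.append(x)
        return x

class GoodLayer:
    """Creates its state once at construction time and reuses it."""
    def __init__(self):
        self.state = []      # single persistent state object

    def __call__(self, x):
        self.state.append(x)
        return x

bad, good = BadLayer(), GoodLayer()
for v in (1, 2, 3):
    bad(v)
    good(v)

print(bad.state)   # [3] -- earlier state was discarded on each call
print(good.state)  # [1, 2, 3] -- one state object survives all calls
```

This is why moving the `BatchNormalization` construction into `__init__` resolves the "singleton tf.Variables" error.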
**To Reproduce**

I referred to the TFLite conversion code on the official TensorFlow site:
```python
import numpy as np
import tensorflow as tf
from conformer_tf import ConformerBlock  # assuming the package exposes ConformerBlock

def conformer():
    input_layer = tf.keras.layers.Input(shape=(1024, 512), batch_size=1)
    conformer_block = ConformerBlock(
        dim=512,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
    )(input_layer)
    return tf.keras.Model(inputs=input_layer, outputs=conformer_block)

def representative_dataset():
    for _ in range(100):
        data = np.random.rand(1, 1024, 512)
        yield [data.astype(np.float32)]

net = conformer()
converter = tf.lite.TFLiteConverter.from_keras_model(net)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()

with open("conformer.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)
```
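Once conversion succeeds, the resulting model can be sanity-checked with `tf.lite.Interpreter`. A minimal sketch of the same int8 conversion-and-check flow, using a small stand-in Dense model so it runs quickly (in practice the Conformer model from above would be used):

```python
import numpy as np
import tensorflow as tf

# Stand-in for the Conformer model: tiny, so full-int8 conversion is fast.
inp = tf.keras.layers.Input(shape=(4,), batch_size=1)
out = tf.keras.layers.Dense(2)(inp)
model = tf.keras.Model(inp, out)

def representative_dataset():
    for _ in range(10):
        yield [np.random.rand(1, 4).astype(np.float32)]

# Same converter settings as the reproduction script above.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()

# Load the converted model and run one inference to verify it works.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp_d = interpreter.get_input_details()[0]
out_d = interpreter.get_output_details()[0]
assert inp_d["dtype"] == np.int8  # inference_input_type took effect
interpreter.set_tensor(inp_d["index"], np.zeros(inp_d["shape"], dtype=np.int8))
interpreter.invoke()
result = interpreter.get_tensor(out_d["index"])
print(result.shape)
```

With the `ConformerBlock` model, the input tensor would have shape `(1, 1024, 512)` instead.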
**Desktop (please complete the following information):**
- OS: Ubuntu 20.04.6 LTS