model-optimization
model-optimization copied to clipboard
`tf.split` or `tf.transpose` cause errors for quantize-aware training with `quantize_apply`
Describe the bug
We are trying to implement a network like ShuffleNetV2 but encounter errors when calling quantize_apply
on the model.
I believe ShuffleNet and related ideas are popular on edge devices, so please kindly help us resolve this problem.
Any advice is welcome.
System information
TensorFlow version (installed from source or binary): 2.7.0
TensorFlow Model Optimization version (installed from source or binary): 0.7.0
Python version: 3.8.13
Describe the expected behavior
quantize_apply should simply add the quantization-aware operators to the model.
Describe the current behavior
When running the provided code, either tf.transpose
or tf.split
causes an error in TensorFlow Model Optimization.
The error message due to tf.split
before convolution layers:
ValueError: Exception encountered when calling layer "bn3" (type BatchNormalization).
Shape must be rank 4 but is rank 5 for '{{node bn3/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, bn3/ReadVariableOp, bn3/ReadVariableOp_1, bn3/FusedBatchNormV3/ReadVariableOp, bn3/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,?,128,128,32], [32], [32], [32], [32].
The error message due to tf.transpose
:
ValueError: Exception encountered when calling layer "tf.compat.v1.transpose" (type TFOpLambda).
Dimension must be 6 but is 5 for '{{node tf.compat.v1.transpose/transpose}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](tf.compat.v1.transpose/transpose/a, tf.compat.v1.transpose/transpose/perm)' with input shapes: [1,?,128,128,2,32], [5].
Code to reproduce the issue
Just run the following code and you will get the error message caused by tf.split
.
from __future__ import annotations
from typing import Callable, Optional
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras import layers
# Substrings of layer names that must not be wrapped with a quantize
# annotation: a layer is left untouched when any entry occurs in its name.
# Every entry needs its own trailing comma; adjacent string literals without
# a comma are silently concatenated into a single entry by Python.
SKIP_LAYER = [
    "resize",
    "Resize",
    "reshape",
    "Reshape",
    "concat",
    "Concat",
    "ExpandDims",
    "Repeats",
    "Shape",
    "strided_slice",
    "Tile",
]
def quantize_model(
    model: tf.keras.Model,
    annotate: Optional[Callable] = None,
    quantize_scope: Optional[dict[str, tf.keras.layers.Layer]] = None,
) -> tf.keras.Model:
    """Clone ``model`` with quantize annotations and apply quantization.

    Args:
        model: The Keras model to make quantization-aware.
        annotate: Optional clone function called with each layer of ``model``;
            it returns the layer to use in the annotated clone. When ``None``,
            a default is used that leaves layers whose name contains any
            ``SKIP_LAYER`` entry untouched and annotates every other layer.
        quantize_scope: Optional mapping passed to
            ``tfmot.quantization.keras.quantize_scope``. ``None`` is treated
            as an empty mapping.

    Returns:
        The model returned by ``tfmot.quantization.keras.quantize_apply``.
    """
    quantize_scope = {} if quantize_scope is None else quantize_scope

    def default_annotate(layer):
        # Skip layers whose name matches one of the SKIP_LAYER substrings.
        if any(name in layer.name for name in SKIP_LAYER):
            return layer
        return tfmot.quantization.keras.quantize_annotate_layer(layer)

    # Honour a caller-supplied annotate function; previously the parameter was
    # shadowed by the inner definition and therefore always ignored.
    clone_function = default_annotate if annotate is None else annotate

    anno_model = tf.keras.models.clone_model(model, clone_function=clone_function)
    with tfmot.quantization.keras.quantize_scope(quantize_scope):
        model = tfmot.quantization.keras.quantize_apply(anno_model)
    return model
def channel_shuffle(tensor: tf.Tensor, groups: int = 2) -> tf.Tensor:
    """Shuffle channels by swapping the group and per-group channel axes.

    The last axis of the rank-4 input is split into
    ``(groups, num_channels // groups)``, those two axes are transposed, and
    the result is flattened back to the original four-dimensional layout.

    Args:
        tensor: Rank-4 tensor; its static shape is unpacked into
            (batch, height, width, num_channels).
        groups: Number of groups; must divide the channel count evenly.

    Returns:
        A tensor reshaped to ``[-1, height, width, num_channels]``.
    """
    _, height, width, num_channels = tensor.shape.as_list()
    assert num_channels % groups == 0

    grouped_shape = [-1, height, width, groups, num_channels // groups]
    flat_shape = [-1, height, width, num_channels]

    grouped = tf.reshape(tensor, grouped_shape)
    # Swap the "group" axis with the "channel within group" axis.
    swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
    swapped = tf.identity(swapped, name="channel_shuffle")
    return tf.reshape(swapped, flat_shape)
def simple_nn(img_input: tf.Tensor) -> tf.Tensor:
    """Apply a 1x1 conv -> 3x3 depthwise conv -> 1x1 conv stack to the input.

    Each stage consumes the output of the previous one. The earlier version
    fed ``img_input`` to ``conv2`` and ``conv3`` directly, which discarded the
    results of conv1/bn1/relu1 and conv2/bn2.

    Args:
        img_input: Input feature map passed to the first convolution.

    Returns:
        The output of ``relu3``, produced by a 32-filter pointwise convolution
        followed by batch normalization and ReLU.
    """
    # Pointwise convolution block.
    latent = layers.Conv2D(32, 1, padding="same", use_bias=False, name="conv1")(img_input)
    latent = layers.BatchNormalization(name="bn1")(latent)
    latent = layers.ReLU(name="relu1")(latent)
    # Depthwise 3x3 convolution block (no activation).
    latent = layers.DepthwiseConv2D(3, 1, padding="same", name="conv2")(latent)
    latent = layers.BatchNormalization(name="bn2")(latent)
    # Final pointwise convolution block.
    latent = layers.Conv2D(32, 1, padding="same", use_bias=False, name="conv3")(latent)
    latent = layers.BatchNormalization(name="bn3")(latent)
    latent = layers.ReLU(name="relu3")(latent)
    return latent
def split_like_nn(img_input: tf.Tensor) -> tf.Tensor:
    """Build a split / process / concat / shuffle unit on top of a conv stem.

    Args:
        img_input: Input tensor fed to the 64-filter 1x1 stem convolution.

    Returns:
        The channel-shuffled concatenation of the processed half and the
        untouched half of the stem output.
    """
    stem = layers.Conv2D(64, 1, padding="same", use_bias=False, name="conv0")(img_input)
    stem = layers.BatchNormalization(name="bn0")(stem)
    stem = layers.ReLU(name="relu0")(stem)

    # Split the channels in two: one half is processed, the other bypasses.
    processed, passthrough = tf.split(stem, 2, axis=-1)
    processed = simple_nn(processed)

    merged = tf.concat([processed, passthrough], axis=-1)
    return channel_shuffle(merged)
if __name__ == "__main__":
    # Build the float model from a single-channel 128x128 input.
    image = tf.keras.Input((128, 128, 1), dtype=tf.float32, name="img")
    float_model = tf.keras.Model(
        inputs=image,
        outputs=split_like_nn(image),
        name="PoseNetV2",
    )
    float_model.summary()

    # Annotate and quantize the model; this is the step that raises the error.
    quantized_model = quantize_model(float_model)
    quantized_model.summary()
If you comment out the following three lines of code, you will get the error message from tf.transpose instead
.
latent_0, latent_1 = tf.split(latent, 2, axis=-1)
latent_0 = simple_nn(latent_0)
latent = tf.concat([latent_0, latent_1], axis=-1)
Hi!
I'm also suffering from the same error using tf.split
Is there any fix coming soon?
Hi, I'm getting the same error too with tf.transpose and tf.permute, any update on a solution?
Hi, tf.nn.depth_to_space causes the same error. I'd be very happy about any advice on how to solve this :)
I think I had the same issue.
I could overcome this error by wrapping tf.split()
in a keras layer:
@keras.saving.register_keras_serializable(package="MyLayers", name="SplitLayer")
class SplitLayer(keras.layers.Layer):
    """Keras layer that forwards its inputs to ``tf.split``.

    Wrapping the op in a named layer lets the quantize-annotation step
    recognise and skip it by name.
    """

    def __init__(self, num_or_size_splits, axis, **kwargs):
        """Store the split arguments; ``kwargs`` go to the base layer."""
        super().__init__(**kwargs)
        self.num_or_size_splits = num_or_size_splits
        self.axis = axis

    def call(self, inputs):
        """Return ``tf.split`` of ``inputs`` using the stored arguments."""
        return tf.split(inputs, self.num_or_size_splits, axis=self.axis)

    def get_config(self):
        """Return the base config extended with the split arguments."""
        base_config = super().get_config()
        base_config.update(
            num_or_size_splits=self.num_or_size_splits,
            axis=self.axis,
        )
        return base_config
I quantized my model (RetinaNet) like this:
def quantize_model(model):
    """Annotate and quantize ``model``, leaving selected layers untouched.

    Layers that are instances of the listed non-quantizable types, or whose
    name contains ``"split"``, are cloned as-is. Every other layer is wrapped
    with ``quantize_annotate_layer`` before ``quantize_apply`` is called.
    """

    def quantize_annotate(layer):
        layer_types_to_avoid = (
            kcv.layers.AnchorGenerator,
            kcv.models.retinanet.LabelEncoder,
            kcv.layers.NonMaxSuppression,
            my_retinanet.MyPredictionDecoder,
        )
        leave_untouched = isinstance(layer, layer_types_to_avoid) or "split" in layer.name
        if not leave_untouched:
            return tfmot.quantization.keras.quantize_annotate_layer(layer)
        return layer

    annotated = tf.keras.models.clone_model(model, clone_function=quantize_annotate)
    with tfmot.quantization.keras.quantize_scope():
        return tfmot.quantization.keras.quantize_apply(annotated)
tf version: 2.15.1 keras version: 2.15.0 tfmot version: 0.7.5