
`tf.split` or `tf.transpose` cause errors for quantization-aware training with `quantize_apply`

Open Janus-Shiau opened this issue 1 year ago • 4 comments

Describe the bug

We are trying to implement a network similar to ShuffleNetV2, but we encounter errors when calling quantize_apply on the model.


I believe ShuffleNet and related ideas are popular on edge devices, so we would appreciate your help in resolving this problem.

Any advice is welcome.

System information

TensorFlow version (installed from source or binary): 2.7.0

TensorFlow Model Optimization version (installed from source or binary): 0.7.0

Python version: 3.8.13

Describe the expected behavior

quantize_apply should simply insert quantization-aware operations into the model.

Describe the current behavior

When running the provided code, either tf.transpose or tf.split causes an error in TensorFlow Model Optimization.

The error caused by tf.split before the convolution layers:

ValueError: Exception encountered when calling layer "bn3" (type BatchNormalization).

Shape must be rank 4 but is rank 5 for '{{node bn3/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, bn3/ReadVariableOp, bn3/ReadVariableOp_1, bn3/FusedBatchNormV3/ReadVariableOp, bn3/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,?,128,128,32], [32], [32], [32], [32].

The error caused by tf.transpose:

ValueError: Exception encountered when calling layer "tf.compat.v1.transpose" (type TFOpLambda).

Dimension must be 6 but is 5 for '{{node tf.compat.v1.transpose/transpose}} = Transpose[T=DT_FLOAT, Tperm=DT_INT32](tf.compat.v1.transpose/transpose/a, tf.compat.v1.transpose/transpose/perm)' with input shapes: [1,?,128,128,2,32], [5].

Code to reproduce the issue

Running the following code produces the tf.split error.

from __future__ import annotations

from typing import Callable, Optional

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras import layers


SKIP_LAYER = [
    "resize",
    "Resize",
    "reshape",
    "Reshape",
    "concat",
    "Concat",
    "ExpandDims",
    "Repeats",
    "Shape",
    "strided_slice",
    "Tile",
]


def quantize_model(
    model: tf.keras.Model,
    annotate: Optional[Callable] = None,
    quantize_scope: Optional[dict[str, tf.keras.layers.Layer]] = None,
) -> tf.keras.Model:
    quantize_scope = {} if quantize_scope is None else quantize_scope

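    def default_annotate(layer):
        # Leave layers whose names contain a SKIP_LAYER entry unannotated.
        if any(name in layer.name for name in SKIP_LAYER):
            return layer
        return tfmot.quantization.keras.quantize_annotate_layer(layer)

    # Prefer the caller-supplied annotate function, if any.
    clone_fn = annotate if annotate is not None else default_annotate
    anno_model = tf.keras.models.clone_model(model, clone_function=clone_fn)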
    with tfmot.quantization.keras.quantize_scope(quantize_scope):
        model = tfmot.quantization.keras.quantize_apply(anno_model)

    return model


def channel_shuffle(tensor: tf.Tensor, groups: int = 2) -> tf.Tensor:
    """Channel shuffle operation."""
    _, height, width, num_channels = tensor.shape.as_list()
    assert num_channels % groups == 0

    tensor = tf.reshape(tensor, [-1, height, width, groups, num_channels // groups])
    tensor = tf.transpose(tensor, [0, 1, 2, 4, 3])
    tensor = tf.identity(tensor, name="channel_shuffle")

    tensor = tf.reshape(tensor, [-1, height, width, num_channels])
    return tensor


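def simple_nn(img_input: tf.Tensor) -> tf.Tensor:
    # Note: conv2 and conv3 read `img_input` rather than `latent`, so only the
    # conv3 -> bn3 -> relu3 branch is connected to the model output.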
    latent = layers.Conv2D(32, 1, padding="same", use_bias=False, name="conv1")(img_input)
    latent = layers.BatchNormalization(name="bn1")(latent)
    latent = layers.ReLU(name="relu1")(latent)

    latent = layers.DepthwiseConv2D(3, 1, padding="same", name="conv2")(img_input)
    latent = layers.BatchNormalization(name="bn2")(latent)

    latent = layers.Conv2D(32, 1, padding="same", use_bias=False, name="conv3")(img_input)
    latent = layers.BatchNormalization(name="bn3")(latent)
    latent = layers.ReLU(name="relu3")(latent)

    return latent


def split_like_nn(img_input: tf.Tensor) -> tf.Tensor:
    latent = layers.Conv2D(64, 1, padding="same", use_bias=False, name="conv0")(img_input)
    latent = layers.BatchNormalization(name="bn0")(latent)
    latent = layers.ReLU(name="relu0")(latent)

    latent_0, latent_1 = tf.split(latent, 2, axis=-1)
    latent_0 = simple_nn(latent_0)
    latent = tf.concat([latent_0, latent_1], axis=-1)

    latent = channel_shuffle(latent)

    return latent


if __name__ == "__main__":
    img_input = tf.keras.Input((128, 128, 1), dtype=tf.float32, name="img")

    outputs = split_like_nn(img_input)

    model = tf.keras.Model(inputs=img_input, outputs=outputs, name="PoseNetV2")
    model.summary()

    model_qat = quantize_model(model)
    model_qat.summary()

Commenting out the following three lines avoids tf.split and produces the tf.transpose error instead:

 latent_0, latent_1 = tf.split(latent, 2, axis=-1)
 latent_0 = simple_nn(latent_0)
 latent = tf.concat([latent_0, latent_1], axis=-1)
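For reference, the channel order that `channel_shuffle` produces can be checked without TensorFlow. This NumPy sketch mirrors its reshape, transpose and reshape steps and compares the result with an explicit index permutation; the names `channel_shuffle_np` and `shuffle_permutation` are illustrative, not from the report. In principle the same permutation could be applied with `tf.gather(tensor, perm, axis=-1)` instead of `tf.transpose`, though that has not been tested with `quantize_apply` here.

```python
import numpy as np


def shuffle_permutation(num_channels: int, groups: int) -> list:
    """Index list such that out[..., k] == in[..., perm[k]] after a channel shuffle."""
    assert num_channels % groups == 0
    per_group = num_channels // groups
    # Output channel k comes from group (k % groups), position (k // groups) within it.
    return [(k % groups) * per_group + k // groups for k in range(num_channels)]


def channel_shuffle_np(x: np.ndarray, groups: int = 2) -> np.ndarray:
    """NumPy mirror of the reshape/transpose/reshape in channel_shuffle."""
    *lead, num_channels = x.shape
    y = x.reshape(*lead, groups, num_channels // groups)
    y = np.swapaxes(y, -1, -2)  # same as transpose [0, 1, 2, 4, 3] on a 5-D tensor
    return y.reshape(*lead, num_channels)


x = np.arange(8).reshape(1, 1, 1, 8)  # one pixel with channels 0..7
print(channel_shuffle_np(x, groups=2)[0, 0, 0])  # [0 4 1 5 2 6 3 7]
print(shuffle_permutation(8, 2))                 # [0, 4, 1, 5, 2, 6, 3, 7]
```

With `groups=2`, the two halves of the channels end up interleaved, which is exactly what an index gather along the last axis reproduces.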

Janus-Shiau, Apr 20 '23

Hi! I'm also running into the same error using tf.split. Is there a fix coming soon?

guillem-ms, May 23 '23

Hi, I'm getting the same error with tf.transpose and tf.permute. Any update on a solution?

DerryFitz, May 24 '23

Hi, tf.nn.depth_to_space causes the same error. I'd appreciate any advice on how to solve this :)

or-ims, Nov 13 '23

I think I had the same issue. I was able to work around this error by wrapping tf.split() in a custom Keras layer:

import tensorflow as tf
from tensorflow import keras


@keras.saving.register_keras_serializable(package="MyLayers", name="SplitLayer")
class SplitLayer(keras.layers.Layer):
    """Wraps tf.split so it appears as a single named Keras layer."""

    def __init__(self, num_or_size_splits, axis, **kwargs):
        super().__init__(**kwargs)
        self.num_or_size_splits = num_or_size_splits
        self.axis = axis

    def call(self, inputs):
        return tf.split(inputs, self.num_or_size_splits, axis=self.axis)

    def get_config(self):
        # Store the constructor arguments so the layer can be re-created when cloned.
        config = super().get_config()
        config.update({
            'num_or_size_splits': self.num_or_size_splits,
            'axis': self.axis,
        })
        return config

I quantized my model (RetinaNet) like this:

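import keras_cv as kcv
import tensorflow as tf
import tensorflow_model_optimization as tfmot

import my_retinanet  # the poster's own module providing MyPredictionDecoder


def quantize_model(model):
    def quantize_annotate(layer):
        # Leave non-quantizable layers (and any layer with "split" in its name,
        # such as SplitLayer) unannotated; annotate everything else.
        layer_types_to_avoid = (
            kcv.layers.AnchorGenerator,
            kcv.models.retinanet.LabelEncoder,
            kcv.layers.NonMaxSuppression,
            my_retinanet.MyPredictionDecoder,
        )
        if isinstance(layer, layer_types_to_avoid) or "split" in layer.name:
            return layer
        return tfmot.quantization.keras.quantize_annotate_layer(layer)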

    annotated_model = tf.keras.models.clone_model(
        model,
        clone_function=quantize_annotate,
    )

    with tfmot.quantization.keras.quantize_scope():
        quantized_model = tfmot.quantization.keras.quantize_apply(annotated_model)

    return quantized_model

Versions: TensorFlow 2.15.1, Keras 2.15.0, tfmot 0.7.5

robertatdm, Mar 19 '24
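Following the same idea as the SplitLayer workaround, the channel shuffle from the original report could be packaged as a single custom layer, so that an annotate function can leave it unannotated by type or by its default name `channel_shuffle`, as the RetinaNet example above does for split layers. This is a hypothetical sketch that has not been tested against tfmot: `ChannelShuffle` and `shuffle_permutation` are illustrative names, and it swaps the reshape + transpose for a `tf.gather` over a precomputed channel permutation that yields the same channel order as `channel_shuffle`.

```python
import tensorflow as tf


def shuffle_permutation(num_channels: int, groups: int) -> list:
    """Channel order produced by reshape -> transpose -> reshape (hypothetical helper)."""
    if num_channels % groups != 0:
        raise ValueError("num_channels must be divisible by groups")
    per_group = num_channels // groups
    # Output channel k comes from group (k % groups), position (k // groups) within it.
    return [(k % groups) * per_group + k // groups for k in range(num_channels)]


@tf.keras.utils.register_keras_serializable(package="MyLayers", name="ChannelShuffle")
class ChannelShuffle(tf.keras.layers.Layer):
    """Hypothetical single-layer channel shuffle using tf.gather instead of transpose."""

    def __init__(self, groups: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.groups = groups

    def call(self, inputs):
        # Assumes a static channel count, as for the Conv/BN outputs in the report.
        perm = shuffle_permutation(int(inputs.shape[-1]), self.groups)
        return tf.gather(inputs, perm, axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({"groups": self.groups})
        return config


x = tf.reshape(tf.range(8, dtype=tf.float32), [1, 1, 1, 8])
print(ChannelShuffle(groups=2)(x).numpy().ravel())  # [0. 4. 1. 5. 2. 6. 3. 7.]
```

Keeping the whole shuffle inside one layer means only one node needs to be skipped during annotation, instead of separate reshape, transpose and identity ops.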