
QAT support for LayerNormalization

Open · tom-arm opened this issue 3 years ago

System information

  • TensorFlow version (you are using): 2.8
  • Are you willing to contribute it (Yes/No): Yes

Motivation: This would benefit models that use this layer, such as Transformer models.

Describe the feature: Be able to run quantization-aware training (QAT) on a model that contains a LayerNormalization layer.
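For context, running QAT on this layer would amount to simulating low-precision arithmetic on its output during training. The following is a minimal pure-Python sketch of that forward pass, assuming per-tensor asymmetric 8-bit fake quantization over a fixed range. This is a simplification: TFMOT's quantizers track the range during training rather than fixing it. `layer_norm` and `fake_quant` are illustrative helpers, not TFMOT APIs.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-3):
    """Normalize a 1-D feature vector, like LayerNormalization over one axis."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(x, gamma, beta)]


def fake_quant(x, x_min, x_max, num_bits=8):
    """Round values to an asymmetric num_bits integer grid, then map back to float."""
    levels = 2 ** num_bits - 1
    scale = (x_max - x_min) / levels
    zero_point = round(-x_min / scale)  # integer that represents 0.0 exactly
    out = []
    for v in x:
        # Quantize and clamp to the representable integer range [0, levels].
        q = min(levels, max(0, round(v / scale) + zero_point))
        # Dequantize back to float, as a fake-quant op does during training.
        out.append((q - zero_point) * scale)
    return out


x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
y_q = fake_quant(y, x_min=-6.0, x_max=6.0)  # values the next layer would see under QAT
print(y)
print(y_q)
```

The gap between `y` and `y_q` is the rounding error that QAT exposes to the model during training. This lets the weights adapt before actual integer conversion.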

Describe how existing APIs don't satisfy your use case: As an example, the following snippet fails at the `quantize_model` call, because `quantize_model` does not support LayerNormalization:

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from tensorflow import keras

model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.LayerNormalization(axis=3),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

quant_model = tfmot.quantization.keras.quantize_model(model)
```

tom-arm, Mar 11 '22 16:03