model-optimization

Error in MovingAverageQuantizer with per_axis=True due to missing parameters in _add_range_weights

Status: Open. Litschi123 opened this issue 1 year ago (0 comments).

Describe the bug

Using MovingAverageQuantizer with per_axis=True results in this error:

    tensorflow_model_optimization\python\core\quantization\keras\quant_ops.py", line 335, in _FakeQuantWithMinMaxVars  *
        assert len(min_var.get_shape()) == 1

The error is caused by the helper function _add_range_weights, which is called by the build method of MovingAverageQuantizer. build does not pass the per_axis and tensor_shape parameters on, so _add_range_weights initializes the min/max variables as scalars. A scalar min_var has rank 0, so it later fails the assert above, which expects a 1-D tensor.
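To make the shape mismatch concrete, here is a small pure-Python sketch (no TensorFlow needed) contrasting a per-tensor range with a per-axis range. The 3x2 "kernel" values and the helper shape_of are hypothetical; the sketch only mirrors the rank check from the traceback and is not TFMOT code.

```python
def shape_of(value):
    """Hypothetical helper: shape of a scalar (rank 0) or a flat list (rank 1)."""
    return [] if isinstance(value, (int, float)) else [len(value)]

# Hypothetical 3x2 kernel; the column index plays the role of the last axis
# (2 output channels).
kernel = [
    [0.5, -1.0],
    [2.0, 0.25],
    [-0.75, 1.5],
]
num_channels = len(kernel[0])

# Per-tensor range: a single scalar for the whole kernel
# (the scalar shape that build() currently produces).
per_tensor_min = min(v for row in kernel for v in row)

# Per-axis range: one value per output channel (what per_axis=True needs).
per_axis_min = [min(row[c] for row in kernel) for c in range(num_channels)]

print(per_tensor_min, shape_of(per_tensor_min))  # -1.0 []
print(per_axis_min, shape_of(per_axis_min))      # [-0.75, -1.0] [2]

# Mirrors `assert len(min_var.get_shape()) == 1` from the traceback:
assert len(shape_of(per_axis_min)) == 1    # a 1-D range passes
assert len(shape_of(per_tensor_min)) != 1  # a rank-0 scalar would fail the real assert
```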

System information

TensorFlow version: 2.12.0 (binary)

TensorFlow Model Optimization version: 0.7.4 (binary)

Python version: 3.9.16

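Describe the expected behavior

With per_axis=True, the min/max variables should hold a list of values, one per entry of the kernel's last axis (the output channels), rather than a single scalar.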

Describe the current behavior

Throws an error because the assert in _FakeQuantWithMinMaxVars fails.

Code to reproduce the issue

import keras
import tensorflow_model_optimization as tfmot

class CustomConv2DQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):

  def get_weights_and_quantizers(self, layer):
    # per_axis=True on the kernel quantizer is what triggers the failing assert.
    return [(layer.kernel, tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
        num_bits=8, per_axis=True, symmetric=True, narrow_range=True))]

  def get_activations_and_quantizers(self, layer):
    return [(layer.activation, tfmot.quantization.keras.quantizers.MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False))]

  def set_quantize_weights(self, layer, quantize_weights):
    layer.kernel = quantize_weights[0]

  def set_quantize_activations(self, layer, quantize_activations):
    layer.activation = quantize_activations[0]

  def get_output_quantizers(self, layer):
    return []

  def get_config(self):
    return {}

annotated_model = keras.Sequential([
    keras.Input(shape=(32,32,3)),
    tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Conv2D(32, kernel_size=(3, 3)), CustomConv2DQuantizeConfig())
])

with tfmot.quantization.keras.quantize_scope(
    {'CustomConv2DQuantizeConfig': CustomConv2DQuantizeConfig}):
    # Fails here with the AssertionError raised in _FakeQuantWithMinMaxVars.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

Additional context

Changing the build method of MovingAverageQuantizer to pass these parameters on, like this:

def build(self, tensor_shape, name, layer):
    return self._add_range_weights(layer, name, self.per_axis, tensor_shape)

fixes the issue and results in the expected behavior.
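The effect of forwarding the two arguments can be sketched without TensorFlow. RecordingLayer and SketchQuantizer below are hypothetical stand-ins, and the shape rule inside _add_range_weights (one value per entry of the last axis when per_axis is set, a scalar otherwise) is an assumed, simplified model of the real helper, not its actual source.

```python
class RecordingLayer:
    """Hypothetical stand-in for a Keras layer: records the shape of each added weight."""

    def __init__(self):
        self.weight_shapes = {}

    def add_weight(self, name, shape=()):
        self.weight_shapes[name] = tuple(shape)
        return name


class SketchQuantizer:
    """Hypothetical, simplified model of MovingAverageQuantizer's range-weight setup."""

    def __init__(self, per_axis):
        self.per_axis = per_axis

    def _add_range_weights(self, layer, name, per_axis=False, tensor_shape=None):
        # Assumed rule: per-axis ranges get one value per entry of the last axis;
        # otherwise the min/max weights are scalars.
        shape = (tensor_shape[-1],) if per_axis and tensor_shape is not None else ()
        return {
            'min_var': layer.add_weight(name + '_min', shape=shape),
            'max_var': layer.add_weight(name + '_max', shape=shape),
        }

    def build_buggy(self, tensor_shape, name, layer):
        # per_axis and tensor_shape are not forwarded, so the defaults yield scalars.
        return self._add_range_weights(layer, name)

    def build_fixed(self, tensor_shape, name, layer):
        # The change proposed in this issue: forward both arguments.
        return self._add_range_weights(layer, name, self.per_axis, tensor_shape)


kernel_shape = (3, 3, 3, 32)  # Conv2D(32, kernel_size=(3, 3)) on a 3-channel input
quantizer = SketchQuantizer(per_axis=True)

buggy_layer, fixed_layer = RecordingLayer(), RecordingLayer()
quantizer.build_buggy(kernel_shape, 'kernel', buggy_layer)
quantizer.build_fixed(kernel_shape, 'kernel', fixed_layer)

print(buggy_layer.weight_shapes)  # {'kernel_min': (), 'kernel_max': ()}
print(fixed_layer.weight_shapes)  # {'kernel_min': (32,), 'kernel_max': (32,)}
```

Here build_buggy models the behavior reported in this issue and build_fixed models the proposed change: only the fixed version gives the min/max weights the 1-D shape that the per-axis assert expects.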

Litschi123, Jan 23 '24 18:01