model-optimization icon indicating copy to clipboard operation
model-optimization copied to clipboard

PCQAT not working if Conv2D has kernel size 1x1

Open YannPourcenoux opened this issue 3 years ago • 8 comments

Describe the bug: When running the Keras example for sparsity- and cluster-preserving quantization-aware training (PCQAT), if I use a Conv2D layer with a kernel size of (1, 1), that layer's kernel contains only zeros after the QAT step of PCQAT. It works fine if the kernel size is (3, 3) or larger.

System information

TensorFlow version (installed from source or binary): 2.9.1 from pip

TensorFlow Model Optimization version (installed from source or binary): 0.7.2 from pip

Python version: 3.9.12

Describe the expected behavior: The behavior should be the same whether the kernel size of the convolutional layer is (1, 1) or (3, 3).

Describe the current behavior: After the QAT step that preserves sparsity and clustering, the final model's Conv2D kernel contains only zeros. Its sparsity is 100%, as shown by the print() output below:

PCQAT Model sparsity:
conv2d/kernel:0: 100.00% sparsity  (16/16)
dense/kernel:0: 61.98% sparsity  (19436/31360)

Code to reproduce the issue

import os
import tempfile
import zipfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Baseline float model. The Conv2D layer is the one whose kernel this issue is
# about: with kernel_size=(1, 1) it ends up all zeros after PCQAT.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28)),
    # Add a trailing channel axis so Conv2D receives (28, 28, 1) inputs.
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
    # If the kernel size of the convolution layer below is (3, 3), as in the
    # tutorial, then everything works as expected.
    tf.keras.layers.Conv2D(filters=16, kernel_size=(1, 1), activation=tf.nn.relu),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    # Raw logits (no softmax); the loss below is built with from_logits=True.
    tf.keras.layers.Dense(10)
])

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(
    optimizer=opt,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model.fit(train_images, train_labels, validation_split=0.1, epochs=10)

_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

# Save the baseline model (without optimizer state) to a fresh temporary .h5 file.
# NOTE(review): the OS-level file descriptor returned by mkstemp is discarded
# here and never closed.
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

# Step 1 of the collaborative-optimization pipeline: magnitude pruning.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Constant target sparsity of 0.5 from step 0; the mask is updated every 100 steps.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}

# Required callback that advances the pruning step counter during fit().
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy']
)

# Fine-tune model
pruned_model.fit(train_images, train_labels, epochs=3, validation_split=0.1, callbacks=callbacks)


def print_model_weights_sparsity(model):
    """Print, for every kernel weight of *model*, the share of entries equal to 0.

    Layers that are ``tf.keras.layers.Wrapper`` instances contribute only their
    trainable weights; all other layers contribute all of their weights. A
    weight is reported only if its name contains "kernel" and does not
    contain "centroid".
    """
    for layer in model.layers:
        is_wrapper = isinstance(layer, tf.keras.layers.Wrapper)
        candidates = layer.trainable_weights if is_wrapper else layer.weights
        kernels = (
            w for w in candidates
            if "kernel" in w.name and "centroid" not in w.name
        )
        for kernel in kernels:
            # Materialize the tensor once and reuse it for both statistics.
            values = kernel.numpy()
            total = values.size
            zeros = np.count_nonzero(values == 0)
            print(
                f"{kernel.name}: {zeros/total:.2%} sparsity ",
                f"({zeros}/{total})",
            )


def print_model_weight_clusters(model):
    """Print the number of distinct values in every kernel weight of *model*.

    The distinct-value count is used as a proxy for the number of weight
    clusters. ``tf.keras.layers.Wrapper`` layers contribute only their
    trainable weights; other layers contribute all of their weights. Weights
    whose name contains "quantize_layer", or lacks "kernel", are skipped.
    """
    for layer in model.layers:
        source = (
            layer.trainable_weights
            if isinstance(layer, tf.keras.layers.Wrapper)
            else layer.weights
        )
        for w in source:
            name = w.name
            # Skip auxiliary quantization weights and anything that is not a kernel.
            if "quantize_layer" in name or "kernel" not in name:
                continue
            print(f"{layer.name}/{name}: {len(np.unique(w))} clusters ")


# Remove the pruning wrappers so the sparse kernels are plain layer weights again.
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

# Step 2: sparsity-preserving weight clustering.
# NOTE(review): this binding is dead -- `cluster_weights` is rebound to the
# experimental `cluster.cluster_weights` two statements below, before it is
# ever used (presumably because only the experimental API accepts
# `preserve_sparsity` in this TFMOT version -- confirm).
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

cluster_weights = cluster.cluster_weights

clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    # Keep the zeros introduced by pruning while clustering the other weights.
    'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)

# Remove the clustering wrappers; both QAT and PCQAT below start from this model.
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)

# QAT
# Reference run: plain quantization-aware training with the default scheme,
# which presumably does not try to preserve clusters or sparsity.
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
# Annotate every layer, then apply the cluster-preserving 8-bit scheme with
# preserve_sparsity=True. The issue reports that after training this model,
# the 1x1 Conv2D kernel is 100% zeros.
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
    stripped_clustered_model
)
pcqat_model = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True)
)

pcqat_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# Compare cluster counts and sparsity of the QAT and PCQAT models.
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)


def get_gzipped_model_size(file):
    """Return the compressed size of *file* in kilobytes (1 KB = 1000 bytes).

    The file is written into a temporary ZIP archive using DEFLATE compression
    (the same algorithm gzip uses), and the archive's on-disk size is
    reported. The temporary archive is deleted before returning.

    Args:
        file: Path of the file to compress.

    Returns:
        Size of the compressed archive in kilobytes, as a float.
    """
    fd, zipped_file = tempfile.mkstemp('.zip')
    # mkstemp hands back an open OS-level descriptor; close it so it is not
    # leaked (ZipFile reopens the path itself).
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(file)
        return os.path.getsize(zipped_file) / 1000
    finally:
        # Do not leave a stray temporary archive behind on every call.
        os.remove(zipped_file)


# QAT model
# Convert the trained Keras QAT model to a TFLite flatbuffer with the default
# optimization set enabled.
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model (relative path, i.e. into the current working directory).
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
# Same conversion for the pruning/cluster-preserving QAT model.
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model (relative path, i.e. into the current working directory).
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

# Compare compressed sizes of the two TFLite files.
print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')


def eval_model(interpreter, images=None, labels=None):
    """Return the top-1 accuracy of a TFLite *interpreter* on labelled images.

    Each image is fed individually, with a batch dimension of 1 and cast to
    float32, into the interpreter's first input tensor. The predicted class is
    the argmax of the first row of the interpreter's first output tensor.

    Args:
        interpreter: An allocated TFLite interpreter (``tf.lite.Interpreter``).
        images: Iterable of input images. Defaults to the module-level
            ``test_images`` (backward-compatible with the original signature).
        labels: Ground-truth class labels aligned with *images*. Defaults to
            the module-level ``test_labels``.

    Returns:
        Fraction of images whose predicted class equals its label.

    Raises:
        ValueError: If the number of images differs from the number of labels.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the evaluation set.
    prediction_digits = []
    for i, image in enumerate(images):
        if i % 1000 == 0:
            print(f"Evaluated on {i} results so far.")
        # Pre-processing: add batch dimension and convert to float32 to match
        # the model's input data format.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        # Run inference.
        interpreter.invoke()

        # Post-processing: `interpreter.tensor` returns a callable yielding a
        # view of the output buffer; read it right after invoke(), drop the
        # batch dimension and take the highest-scoring class.
        output = interpreter.tensor(output_index)
        prediction_digits.append(np.argmax(output()[0]))

    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    labels = np.asarray(labels)
    if len(prediction_digits) != len(labels):
        raise ValueError(
            f"got {len(prediction_digits)} images but {len(labels)} labels"
        )
    return (prediction_digits == labels).mean()


# Evaluate each saved TFLite model on the MNIST test split; eval_model called
# with only an interpreter uses the module-level test_images/test_labels.
interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()
pcqat_test_accuracy = eval_model(interpreter)

interpreter = tf.lite.Interpreter(qat_model_file)
interpreter.allocate_tensors()
qat_test_accuracy = eval_model(interpreter)

# Report TFLite accuracies next to the float Keras baseline measured earlier.
print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('quantized TFLite test_accuracy:', qat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)

YannPourcenoux avatar Jun 09 '22 21:06 YannPourcenoux

@rino20 Hi Rino, could you help fix this issue?

inho9606 avatar Jun 14 '22 01:06 inho9606

Can I get an update on this? This is one of your tutorials, and 1x1 convolutions are among the most commonly used layers in deep learning for computer vision.

YannPourcenoux avatar Jun 22 '22 08:06 YannPourcenoux

@wwwind Could you take a look? Thanks.

rino20 avatar Jun 22 '22 12:06 rino20

Hi @rino20 Yes, we will take a look today/tomorrow at this issue.

wwwind avatar Jun 22 '22 13:06 wwwind

Hi @YannPourcenoux I'm taking a look at this now, also. Will keep you posted.

jamwar01 avatar Jun 22 '22 14:06 jamwar01

@YannPourcenoux We have found the source of the problem and are now working towards a solution.

jamwar01 avatar Jun 23 '22 09:06 jamwar01

Great! Thanks! Looking forward to hearing from you again 😁

YannPourcenoux avatar Jun 23 '22 09:06 YannPourcenoux

Hi @YannPourcenoux, a PR has been created for the issue (linked above). In the meantime, while waiting for it to be merged, you may download the patch and check whether you get the desired behaviour with 1x1 kernel sizes. Thanks for drawing attention to this bug :)

jamwar01 avatar Jun 30 '22 17:06 jamwar01