PCQAT not working if Conv2D has kernel size 1x1
Describe the bug
When following the sparsity and cluster preserving quantization aware training (PCQAT) Keras example, if I use a Conv2D layer with a kernel size of (1, 1), the weights of that layer are all zeros after the QAT step of PCQAT.
It works fine if the kernel size is (3, 3) or bigger.
System information
TensorFlow version (installed from source or binary): 2.9.1 from pip
TensorFlow Model Optimization version (installed from source or binary): 0.7.2 from pip
Python version: 3.9.12
Describe the expected behavior
Having the same behavior whether the kernel size of the convolutional layer is (1, 1) or (3, 3).
Describe the current behavior
After the QAT step that preserves sparsity and clustering, the final model contains only zeros in the Conv2D kernel. Its sparsity is 100%, as shown by the output of print_model_weights_sparsity():
PCQAT Model sparsity:
conv2d/kernel:0: 100.00% sparsity (16/16)
dense/kernel:0: 61.98% sparsity (19436/31360)
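For reference, a minimal check along these lines is enough to show that the Conv2D kernel has collapsed to all zeros. This is just a sketch: it assumes the pcqat_model built by the reproduction script below has already been trained, and the helper name and the weight-name filter are mine, not part of the tutorial.

import numpy as np

# Minimal sketch: count the zero-valued weights in the Conv2D kernel of the
# trained PCQAT model. Assumes `pcqat_model` is the model produced by
# quantize_apply() in the reproduction script after fitting.
def report_conv_kernel_zeros(model):
    for weight in model.weights:
        # Filtering by name is an assumption about how the quantize wrapper
        # names the wrapped Conv2D kernel.
        if "conv2d" in weight.name and "kernel" in weight.name:
            values = weight.numpy()
            zeros = np.count_nonzero(values == 0)
            print(f"{weight.name}: {zeros}/{values.size} weights are zero")

report_conv_kernel_zeros(pcqat_model)
# With kernel_size=(1, 1) this reports 16/16 zeros; with (3, 3) it does not.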
Code to reproduce the issue
import os
import tempfile
import zipfile
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28)),
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
    # If the kernel size of the convolution layer below is (3, 3), as in the
    # tutorial, then everything works as expected
    tf.keras.layers.Conv2D(filters=16, kernel_size=(1, 1), activation=tf.nn.relu),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10)
])
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Train the digit classification model
model.compile(
    optimizer=opt,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
model.fit(train_images, train_labels, validation_split=0.1, epochs=10)
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
pruned_model = prune_low_magnitude(model, **pruning_params)
# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
pruned_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy']
)
# Fine-tune model
pruned_model.fit(train_images, train_labels, epochs=3, validation_split=0.1, callbacks=callbacks)
def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )
def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(f"{layer.name}/{weight.name}: {unique_count} clusters ")
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
print_model_weights_sparsity(stripped_pruned_model)
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
# Use the experimental cluster_weights API, which supports the preserve_sparsity option
cluster_weights = cluster.cluster_weights
clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    'preserve_sparsity': True
}
sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)
sparsity_clustered_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)
print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)
print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)
qat_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
    stripped_clustered_model
)
pcqat_model = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True)
)
pcqat_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)
def get_gzipped_model_size(file):
    # It returns the size of the gzipped model in kilobytes.
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    return os.path.getsize(zipped_file) / 1000
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)
# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)
print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')
def eval_model(interpreter):
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # Run predictions on every image in the "test" dataset.
    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print(f"Evaluated on {i} results so far.")
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)
        # Run inference.
        interpreter.invoke()
        # Post-processing: remove batch dimension and find the digit with highest
        # probability.
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)
    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy
interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()
pcqat_test_accuracy = eval_model(interpreter)
interpreter = tf.lite.Interpreter(qat_model_file)
interpreter.allocate_tensors()
qat_test_accuracy = eval_model(interpreter)
print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('quantized TFLite test_accuracy:', qat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)
@rino20 Hi Rino, could you help fix this issue?
Can I get an update on this? This is one of your tutorials, and 1x1 convolutions are among the most used layers in deep learning for computer vision.
@wwwind Could you take a look? Thanks.
Hi @rino20 Yes, we will take a look at this issue today/tomorrow.
Hi @YannPourcenoux, I'm also taking a look at this now. Will keep you posted.
@YannPourcenoux We have found the source of the problem and are now working towards a solution.
Great! Thanks! Looking forward to hearing from you again 😁
Hi @YannPourcenoux, a PR has been created for the issue (linked above). While waiting for it to be merged, you may download the patch and see for yourself whether you get the desired behaviour with the 1x1 kernel sizes. Thanks for drawing attention to this bug :)
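A quick way to check could be something like the following. This is only a sketch: it assumes you re-run the reproduction script above with the patched package installed and that pcqat_model is the resulting PCQAT model; the assertion itself is not part of the library.

import numpy as np

# Sketch of a quick regression check after re-running the script with the
# patched tfmot. Passes only if the (1, 1) Conv2D kernel is no longer all zeros.
conv_kernels = [w for w in pcqat_model.weights
                if "conv2d" in w.name and "kernel" in w.name]
assert conv_kernels, "No Conv2D kernel found in the PCQAT model"
assert any(np.any(k.numpy() != 0) for k in conv_kernels), (
    "Conv2D (1, 1) kernel is still all zeros after PCQAT")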