CQAT fails to preserve clusters on ResNet-50
Prior to filing: check that this should be a bug instead of a feature request. Everything supported, including the compatible versions of TensorFlow, is listed in the overview page of each technique. For example, the overview page of quantization-aware training is here. An issue for anything not supported should be a feature request.
Describe the bug: CQAT does not preserve weight clusters when training ResNet-50 on CIFAR-100.
System information
TensorFlow version (installed from source or binary): TensorFlow 2.5
TensorFlow Model Optimization version (installed from source or binary): 0.7.3
Python version: 3.7.13
Describe the expected behavior: Model weight clusters are preserved after cluster-preserving quantization-aware training (CQAT).
Describe the current behavior: Model weight clusters are not preserved for some kernels after cluster-preserving quantization-aware training (CQAT).
Code to reproduce the issue: Provide reproducible code that is the bare minimum necessary to generate the problem.
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import zipfile
# Load CIFAR-100 and hold out the first 10% of the training split for
# validation.  The hold-out is carved out with the TFDS split-slicing API
# rather than take()/skip(): with shuffle_files=True the underlying files may
# be read in a different order on each iteration, so take()/skip() on the
# loaded dataset is not guaranteed to give disjoint train/validation sets.
(ds_val, ds_train, ds_test), ds_info = tfds.load(
    'cifar100',
    split=['train[:10%]', 'train[10%:]', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
batch_size = 32
# Sizes of the two training sub-splits (also used as the shuffle buffer size).
num_val = int(ds_info.splits['train'].num_examples * 0.1)
num_train = ds_info.splits['train'].num_examples - num_val
def normalize_img(image, label):
    """Convert an image to `float32` (e.g. `uint8` -> `float32` in [0, 1]).

    The label is passed through untouched, so this can be used directly
    with `Dataset.map` on `(image, label)` pairs.
    """
    float_image = tf.image.convert_image_dtype(image, tf.float32)
    return float_image, label
def augment_img(image, label):
    """Randomly augment a 32x32 RGB image; the label is passed through.

    Always applies a random horizontal flip.  Then, each with probability 0.5
    and independently, it applies a random brightness shift and a random 90%
    crop that is resized back to 32x32.  The result is clipped to [0, 1],
    which is the range produced by `normalize_img`.
    """
    image = tf.image.random_flip_left_right(image)
    # Two independent coin flips.  The tensor-valued `if`s below rely on
    # AutoGraph converting them to graph conditionals when this function is
    # traced by `Dataset.map`.
    rand = tf.random.uniform([2], minval=0, maxval=1)
    if rand[0] > 0.5:
        image = tf.image.random_brightness(image, 0.1)
    if rand[1] > 0.5:
        crop_factor = 0.9
        crop_size = int(32 * crop_factor)
        image = tf.image.random_crop(image, (crop_size, crop_size, 3))
        image = tf.image.resize(image, (32, 32))
    # random_brightness adds a delta in [-0.1, 0.1], which can push pixel
    # values outside [0, 1]; clip them back into range.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
def resize_img(image, label):
    """Resize an image to 224x224 (the model's input size); label unchanged."""
    return tf.image.resize(image, (224, 224)), label
# Training pipeline.  Shuffle the cached 32x32 images *before* augmentation
# and the upscale to 224x224.  A shuffle buffer of `num_train` elements holds
# whole images, and each buffered element would otherwise be a 224x224x3
# float32 image (~600 KB) instead of a 32x32x3 one (~12 KB).  Augmentation
# is per-element and random, so shuffling first does not change the data
# distribution.
ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(num_train)
ds_train = ds_train.map(augment_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.map(resize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)


def _prepare_eval(ds):
    """Normalize, cache, resize, batch and prefetch an evaluation split.

    The cache sits before the 224x224 resize so that it keeps the small
    32x32 images in memory rather than the upscaled ones.
    """
    ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.cache()
    ds = ds.map(resize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)


ds_val = _prepare_eval(ds_val)
ds_test = _prepare_eval(ds_test)
# ImageNet-pretrained ResNet-50 backbone with a new 100-way head that outputs
# logits (hence from_logits=True in the loss).  The whole network is trained.
inputs = tf.keras.Input(shape=(224, 224, 3))
base_model = tf.keras.applications.resnet50.ResNet50(
    include_top=False,
    weights='imagenet',
    input_tensor=inputs,
)
x = tf.keras.layers.Flatten()(base_model.output)
outputs = tf.keras.layers.Dense(100)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
base_model.trainable = True
model.summary()

initial_lr = 0.001
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=initial_lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

initial_epochs = 10
# Cosine decay indexed by epoch: LearningRateScheduler evaluates the schedule
# with the epoch number.  alpha=0.027 (= 0.3 ** 3) is the floor as a fraction
# of initial_lr, comparable to dropping the learning rate 3 times by a factor
# of 0.3.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_lr,
    decay_steps=initial_epochs,
    alpha=0.027,
)
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)
history = model.fit(
    ds_train,
    epochs=initial_epochs,
    validation_data=ds_val,
    callbacks=[lr_callback],
)
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster

# Weight clustering: wrap the trained model so that its clustered weights are
# restricted to 16 shared values, with centroids initialized by k-means++.
cluster_weights = cluster.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
    'number_of_clusters': 16,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    # 'cluster_per_channel': True
}
model_for_clustering = cluster_weights(model, **clustering_params)

# Fine-tune at the floor of the cosine schedule used above
# (alpha * initial_lr).
lr = 0.027 * initial_lr
model_for_clustering.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
model_for_clustering.summary()
model_for_clustering.fit(
    ds_train,
    epochs=initial_epochs,
    validation_data=ds_val,
)

# Remove the clustering wrappers, leaving a plain Keras model.
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(model_for_clustering)
def print_model_weight_clusters(model, pre_layer_name=""):
    """Print the number of unique values in each kernel weight of `model`.

    For a clustered kernel, this count is the number of distinct centroid
    values actually in use.  Sub-models (any layer exposing `.layers`) are
    walked recursively, and their names are joined with "/" into a prefix.

    Args:
        model: A Keras model, or any layer exposing `.layers`.
        pre_layer_name: "/"-joined names of the enclosing layers; "" at the
            top level.
    """
    for layer in model.layers:
        if hasattr(layer, 'layers'):
            # Build the child prefix in its own variable.  Overwriting
            # `pre_layer_name` and then resetting it to "" (as before) lost the
            # enclosing prefix for every later layer at nesting depth >= 1.
            if pre_layer_name == "":
                child_prefix = layer.name
            else:
                child_prefix = '{}/{}'.format(pre_layer_name, layer.name)
            print_model_weight_clusters(layer, child_prefix)
            continue
        if isinstance(layer, tf.keras.layers.Wrapper):
            # NOTE(review): for wrappers only the trainable weights are
            # inspected, so a kernel held as a non-trainable variable is not
            # reported.
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                if pre_layer_name == "":
                    print(f"{layer.name}/{weight.name}: {unique_count} clusters ")
                else:
                    print(f"{pre_layer_name}/{layer.name}/{weight.name}: {unique_count} clusters ")
print_model_weight_clusters(stripped_clustered_model)

# CQAT: cluster-preserving quantization-aware training of the stripped,
# clustered model with the default 8-bit scheme (sparsity not preserved).
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
    stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(
        preserve_sparsity=False))
lr = 0.027 * initial_lr
cqat_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
cqat_model.summary()

# Train on a random subset of the training data.  ds_train is already
# shuffled with a buffer covering the whole training set (`shuffle(num_train)`)
# and is reshuffled on every iteration, so taking its first `subset_size`
# examples yields a fresh random subset each epoch.  A further
# `shuffle(num_train)` here would only buffer up to num_train fully resized
# 224x224 images without changing the result.
subset_size = 2000
subset_ds_train = (
    ds_train.unbatch()
    .take(subset_size)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
history = cqat_model.fit(
    subset_ds_train,
    epochs=initial_epochs,
    validation_data=ds_val,
)
print_model_weight_clusters(cqat_model)
Screenshots If applicable, add screenshots to help explain your problem.
Additional context Add any other context about the problem here.
This might be a bug related to the feature being experimental. @MatteoArm, do you have any idea what could be causing it? Thanks!