
Pruning fails for siamese networks

Open cacfd3a opened this issue 5 years ago • 4 comments

Describe the bug
Follow along with the MNIST siamese example, in which one set of weights is used twice in the same network. Prune a single layer and training fails with the error below, even though the pruning callback is supplied to fit():

tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:x (assert_greater_equal/ReadVariableOp:0) = ]

System information

TensorFlow installed from (source or binary): binary

TensorFlow version: 2.1.0

TensorFlow Model Optimization version: 0.2.1

Python version: 3.6.10

Describe the expected behavior
The script doesn't crash.

Describe the current behavior
The script crashes with the assertion error above.

Code to reproduce the issue

import random
import tempfile

import numpy as np
import tensorflow as tf

from tensorflow.keras import backend as keras_backend
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow_model_optimization.sparsity import keras as sparsity

def euclidean_distance(vects):
    x, y = vects
    sum_square = keras_backend.sum(keras_backend.square(x - y), axis=1, keepdims=True)
    return keras_backend.sqrt(keras_backend.maximum(sum_square, keras_backend.epsilon()))


def eucl_dist_output_shape(shapes):
    shape1, shape2 = shapes
    return shape1[0], 1

def random_other_digit(digit, max_num):
    inc = random.randrange(1, max_num)
    other_digit = (digit + inc) % max_num
    return other_digit

def create_pairs(x, digit_indices_in_dataset, num_classes=10):
    """Build positive (same-digit) and negative (different-digit) pairs from MNIST."""
    pairs = []
    labels = []
    minimal_digit_set_size = min([len(digit_indices_in_dataset[d]) for d in range(num_classes)]) - 1
    if minimal_digit_set_size <= 0:
        raise ValueError("Not enough examples per digit: %d" % minimal_digit_set_size)
    for digit in range(num_classes):
        for i in range(minimal_digit_set_size):
            indices_for_digit = digit_indices_in_dataset[digit]

            index_for_digit = indices_for_digit[i]
            digit_image_1 = x[index_for_digit]

            digit_index_same = indices_for_digit[i + 1]
            digit_image_same = x[digit_index_same]
            pair_same_digits = [digit_image_1, digit_image_same]
            pairs.append(pair_same_digits)

            other_digit = random_other_digit(digit, num_classes)
            digit_image_other = x[digit_indices_in_dataset[other_digit][i]]
            pair_different_digits = [digit_image_1, digit_image_other]
            pairs.append(pair_different_digits)

            # [Same, Different]
            labels += [1, 0]
    return np.array(pairs), np.array(labels)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train[0:5000]
y_train = y_train[0:5000]

x_test = x_test[0:5000]
y_test = y_test[0:5000]

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]
num_classes = 10

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]

tr_pairs, tr_y = create_pairs(x_train, digit_indices, num_classes)
print("Pairs created")

print("Creating testing pairs...")
digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices, num_classes)

def create_base_network_pruned(input_shape, begin_step, end_step):
    """Base network to be shared (eq. to feature extraction).
    """
    pruning_params = {
        'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                     final_sparsity=0.90,
                                                     begin_step=begin_step,
                                                     end_step=end_step,
                                                     frequency=100)
    }
    input_base = Input(shape=input_shape)
    x = Flatten()(input_base)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.1)(x)
    x = sparsity.prune_low_magnitude(Dense(128, activation='relu'),
                                     **pruning_params)(x)
    x = Dropout(0.1)(x)
    x = Dense(128, activation='relu')(x)
    return Model(input_base, x)

def siamese_dense_pruned(input_shape, begin_step, end_step):
    base_network = create_base_network_pruned(input_shape, begin_step, end_step)
    input_a = Input(shape=input_shape)
    input_b = Input(shape=input_shape)
    # because we re-use the same instance `base_network`,
    # the weights of the network
    # will be shared across the two branches
    processed_a = base_network(input_a)
    processed_b = base_network(input_b)
    distance = Lambda(euclidean_distance,
                      output_shape=eucl_dist_output_shape)([processed_a, processed_b])
    return Model([input_a, input_b], distance)

epochs = 20
batch_size = 128
num_train_samples = x_train.shape[0]
end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs
begin_step = end_step // 5  # the pruning schedule expects an integer step
model = siamese_dense_pruned(input_shape, begin_step, end_step)

# Loss and metric as in the Keras MNIST siamese example.
def contrastive_loss(y_true, y_pred):
    margin = 1
    y_true = keras_backend.cast(y_true, y_pred.dtype)
    square_pred = keras_backend.square(y_pred)
    margin_square = keras_backend.square(keras_backend.maximum(margin - y_pred, 0))
    return keras_backend.mean(y_true * square_pred + (1 - y_true) * margin_square)

def accuracy(y_true, y_pred):
    return keras_backend.mean(
        keras_backend.equal(y_true, keras_backend.cast(y_pred < 0.5, y_true.dtype)))

logdir = tempfile.mkdtemp()  # any writable directory for the pruning summaries
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
]
model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y),
          callbacks=callbacks)

Additional context
This might have something to do with referencing the same layer instance twice in a model.
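For reference, a minimal sketch that isolates the suspected cause: one prune-wrapped layer reached twice through a shared submodel. The shapes and layer sizes here are arbitrary, and I'd expect it to hit the same assertion (untested beyond the versions listed above):

import numpy as np
from tensorflow.keras.layers import Input, Dense, concatenate
from tensorflow.keras.models import Model
from tensorflow_model_optimization.sparsity import keras as sparsity

# A shared submodel containing a single prune-wrapped Dense layer.
inp = Input(shape=(4,))
out = sparsity.prune_low_magnitude(Dense(4))(inp)
shared = Model(inp, out)

# The same submodel (and hence the same pruning wrapper) is called twice.
in_a, in_b = Input(shape=(4,)), Input(shape=(4,))
merged = concatenate([shared(in_a), shared(in_b)])
model = Model([in_a, in_b], Dense(1)(merged))

model.compile(loss='mse', optimizer='adam')
x = np.random.rand(8, 4).astype('float32')
# Should raise the same InvalidArgumentError even though
# UpdatePruningStep is supplied as a callback.
model.fit([x, x], np.zeros((8, 1)),
          callbacks=[sparsity.UpdatePruningStep()])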

cacfd3a commented on Feb 26, 2020

@liyunlu0618: do you have time to look into this now? Otherwise, I'll get back to it in 2-3 weeks.

alanchiao commented on Feb 26, 2020

@digitalheir: I won't be able to get to this right now. If this blocks you (and you're not able to make use of pruning in other ways), please let us know.

alanchiao commented on Feb 28, 2020

That's fine. Pruning is not a blocking issue for me. I wanted to optimize my model for mobile and already got good results with quantization alone. Thanks for the swift response!
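For anyone curious, the quantization I used was just the standard post-training flow via the TFLite converter; a rough sketch from memory (file name illustrative):

import tensorflow as tf

# Post-training dynamic-range quantization of the trained Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('siamese_quantized.tflite', 'wb') as f:
    f.write(tflite_model)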

cacfd3a commented on Feb 29, 2020

Hello, is there a way to solve this problem? We can't even get the following tutorial to work: https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb I would be interested in pruning. Thanks!

ansacaron commented on Mar 7, 2024