get_weights()/set_weights() take too long

Open MatthewWiens101 opened this issue 3 years ago • 6 comments

System information.

  • TensorFlow version: 2.3.0
  • CUDA version: 10.1
  • GPU compute capability: 5.2

Describe the problem.

I am running some code that calls layer.get_weights() and layer.set_weights() on every training iteration. The callback containing these calls takes about 9 ms (0.009 s), compared to the 3 ms (0.003 s) taken to run the batch, so it more than triples the total training time. I assumed this operation simply moves tensors around (which should happen only on the GPU) and therefore should not take time comparable to the large matrix multiplications performed during the batch update. I have reviewed the source code, and to the best of my understanding this is what happens. However, it is clearly taking an unusually long time. Does anyone know why this happens, or have any approaches to reduce the time taken by set_weights() and get_weights()? The long runtime may be caused by the way get_weights()/set_weights() are structured, which is why I am reporting this as a bug.

My intuition is that it may be due to data being sent to the CPU and back, or being converted from tensors to NumPy arrays. Or perhaps, on each set_weights() call, TensorFlow rebuilds the entire graph from scratch or does something similar.
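The first hypothesis can be checked directly: get_weights() returns NumPy arrays in host memory, and set_weights() writes NumPy-compatible values back into the layer's variables. When the variables live on a GPU, each call therefore implies a device-to-host or host-to-device copy, on top of the Python overhead per layer. The sketch below uses a small standalone Dense layer (not the model from this issue) to show what the two calls return and accept.

```python
import numpy as np
import tensorflow as tf

# A standalone Dense layer; build() creates its kernel and bias variables.
layer = tf.keras.layers.Dense(4)
layer.build((None, 3))

# get_weights() hands back NumPy arrays, not the variables themselves.
kernel, bias = layer.get_weights()
print(type(kernel).__name__, type(bias).__name__)  # ndarray ndarray
print(isinstance(layer.kernel, np.ndarray))  # False

# set_weights() takes NumPy values and writes them into the variables.
layer.set_weights([np.zeros_like(kernel), bias])
print(np.count_nonzero(layer.kernel.numpy()))  # 0
```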

One thing I noticed is that Keras has its own pruning functionality, shown here, and it also has a long callback runtime (see below). Perhaps this is related?

3/422 [..............................] - ETA: 12s - loss: 0.0628 - accuracy: 0.9896  
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0075s vs `on_train_batch_end` time: 0.0076s). Check your callbacks.

Describe the current behavior.

The on_train_batch_end() callback in the code below calls get_weights() twice and set_weights() once per layer, and takes three times as long to run as the batch update:

Epoch 1/40
  1/629 [..............................] - ETA: 0s - loss: 2.3555 - accuracy: 0.0000e+00
  WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0030s vs `on_train_batch_end` time: 0.0090s). Check your callbacks.

This is specifically caused by the get_weights() and set_weights() calls: removing them from the callback reduces its runtime to a negligible amount.
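If most of the cost comes from the NumPy round trip, one possible mitigation (a sketch, not tested against the exact setup in this issue) is to zero the pruned weights with an in-place Variable.assign inside a tf.function, so the kernels never leave the device. AssignMaskCallback and its masks argument (a list of 0/1 arrays ordered like model.layers, for example built from mask_dict) are hypothetical names. Only the masking step of pruneModelCallback is covered; the bias is left untouched, and rewinding to the initial weights is not handled here.

```python
import numpy as np
import tensorflow as tf


class AssignMaskCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: zeroes pruned kernel entries in place with
    Variable.assign, avoiding the get_weights()/set_weights() round trip."""

    def __init__(self, masks):
        super().__init__()
        # One 0/1 mask per layer, in model.layers order, shaped like its kernel.
        self.masks = [tf.constant(m, dtype=tf.float32) for m in masks]
        # Traced on first call; later calls run the whole loop as one graph.
        self._apply_masks_fn = tf.function(self._apply_masks)

    def _apply_masks(self):
        for layer, mask in zip(self.model.layers, self.masks):
            layer.kernel.assign(layer.kernel * mask)

    def on_train_batch_end(self, batch, logs=None):
        self._apply_masks_fn()


# Usage sketch on a tiny stand-in model with random masks.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4),
    tf.keras.layers.Dense(2),
])
rng = np.random.default_rng(0)
masks = [(rng.random(tuple(l.kernel.shape)) > 0.5).astype("float32")
         for l in model.layers]
cb = AssignMaskCallback(masks)
cb.set_model(model)
cb.on_train_batch_end(0)
```

In training it would be passed through model.fit(callbacks=[...]). It still runs once per batch outside the compiled train step, so some overhead likely remains compared with folding the mask into the model itself, as the last comment in this thread suggests.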

Describe the expected behavior.

Ideally, I would like to achieve iterative magnitude pruning with the lowest possible runtime.

Standalone code to reproduce the issue.

import sys
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.datasets import mnist

### PRUNE WEIGHTS CALLBACK ###
class pruneModelCallback(Callback):
    def __init__(self, init_weight_dict=None, mask_dict=None):
        super().__init__()
        self.n_batches = 0
        self.init_weight_dict = init_weight_dict
        self.mask_dict = mask_dict
    
    def on_train_batch_begin(self, batch, logs=None):
        # save weights at initialization
        if self.n_batches == 0:
            if self.init_weight_dict is not None:
                for layer_i in range(len(self.model.layers)):
                    w = self.init_weight_dict['w_'+str(layer_i+1)]
                    b = self.init_weight_dict['b_'+str(layer_i+1)]

                    self.model.layers[layer_i].set_weights([w,b])
            else:
                self.init_weight_dict = {}
                for layer_i in range(len(self.model.layers)):
                    w = self.model.layers[layer_i].get_weights()[0]
                    b = self.model.layers[layer_i].get_weights()[1]

                    self.init_weight_dict['w_'+str(layer_i+1)] = w
                    self.init_weight_dict['b_'+str(layer_i+1)] = b
        
        self.n_batches = self.n_batches + 1
        
    # This is the problematic function, runs every training iteration batch
    def on_train_batch_end(self, batch, logs=None):
        # zero out pruned weights
        if self.mask_dict is not None:
            for layer_i in range(len(self.model.layers)):
                # removing these slightly improves runtime
                w = self.model.layers[layer_i].get_weights()[0]
                b = self.model.layers[layer_i].get_weights()[1]

                w_mask = self.mask_dict['w_'+str(layer_i+1)]

                # this multiplication takes no time comparably and removing it 
                # does not influence time taken
                w_pruned = w * w_mask

                # removing this function call significantly speeds up the runtime
                self.model.layers[layer_i].set_weights([w_pruned,b])

class pruneWeights():
    def __init__(self, model, percentile, pruning_type="IMP"):
        # generate pruned mask
        if pruning_type == "IMP":
            # __init__ must not return a value; _IMP stores the mask on self
            self._IMP(model, percentile)
        else:
            raise ValueError("Unknown pruning_type {}".format(pruning_type))
    
    def _IMP(self, model, percentile):
        mask_dict = {}
        w_list = None
        for layer_i in range(len(model.layers)):
            w = model.layers[layer_i].get_weights()[0]
            # flatten the kernel into a 1-D tensor
            w_flat = tf.reshape(w, [-1])
            if w_list is None:
                w_list = w_flat
            else:
                w_list = tf.concat([w_list, w_flat], axis=0)
        w_list = tf.math.abs(w_list)
        thresh = tfp.stats.percentile(w_list, percentile*100)
        test_mask = tf.cast(tf.math.greater(w_list, thresh), tf.float32)
        for layer_i in range(len(model.layers)):
            w = model.layers[layer_i].get_weights()[0]
            w = tf.math.abs(w)
            mask = tf.cast(tf.math.greater(w, thresh), tf.float32)
            mask_dict['w_'+str(layer_i+1)] = mask
        self.mask_dict = mask_dict

def pruning_breakdown(mask_dict, model):
    for layer_i in range(len(model.layers)):
        mask = mask_dict['w_'+str(layer_i+1)]
        print("w_"+str(layer_i+1)+": "+str(tf.math.reduce_mean(mask)))

def main():
    ### LOAD MNIST DATASET ###

    (x_train , y_train), (x_test , y_test) = mnist.load_data()
    x = np.concatenate((x_train, x_test), axis=0)
    y = np.concatenate((y_train, y_test), axis=0)
    x= x.astype("float32") / 255
    x= np.reshape(x, (np.shape(x)[0], 784))
    scaler = StandardScaler(with_std=False)
    scaler.fit(x)
    x_t= scaler.transform(x)
    ohe = OneHotEncoder()
    y_t= ohe.fit_transform(y.reshape(-1, 1)).toarray()
    input_dim = np.shape(x_t)[1:]
    output_dim = np.shape(y_t)[1:]

    del x_train, x_test, y_train, y_test, x, y

    ### SPLIT TRAINING DATA ###

    X_train, X_test, y_train, y_test = train_test_split(x_t, y_t, test_size=0.33, random_state=42)

    idxs = tf.range(tf.shape(X_train)[0])

    ### MODEL INIT ###

    model = Sequential([
        Dense(300, input_dim=input_dim[0], activation='relu'),
        Dense(100, activation='relu'),
        Dense(50, activation='relu'),
        Dense(output_dim[0], activation='softmax')
    ])

    model.compile(
        optimizer = keras.optimizers.Adam(learning_rate=1.2e-4),
        loss = tf.keras.losses.CategoricalCrossentropy(),
        metrics = ['accuracy']
    )

    ### TRAIN MODEL ###

    epochs = 20

    pr = pruneModelCallback()
    es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)

    history = model.fit(
        x=X_train,
        y=y_train,
        batch_size=64,
        epochs=epochs,
        validation_data=(X_test,y_test),
        callbacks = [pr, es],
    )

    ### ITERATIVELY PRUNE MODEL ###

    percentile_per_it = 0.20
    it = 14

    for i in range(it):
        total_percentile = 1-tf.math.pow(1-percentile_per_it, i+1)
        print(total_percentile)
        pruned_weights = pruneWeights(model, total_percentile)
        pruning_breakdown(pruned_weights.mask_dict, model)

        model = Sequential([
            Dense(300, input_dim=input_dim[0], activation='relu'),
            Dense(100, activation='relu'),
            Dense(50, activation='relu'),
            Dense(output_dim[0], activation='softmax')
        ])

        model.compile(
            optimizer = keras.optimizers.Adam(learning_rate=1.2e-4),
            loss = tf.keras.losses.CategoricalCrossentropy(),
            metrics = ['accuracy']
        )

        ### TRAIN MODEL ###

        epochs = 40

        pr_next = pruneModelCallback(init_weight_dict=pr.init_weight_dict, mask_dict=pruned_weights.mask_dict)
        es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)

        history = model.fit(
            x=X_train,
            y=y_train,
            batch_size=64,
            epochs=epochs,
            validation_data=(X_test,y_test),
            callbacks = [pr_next, es],
        )

        pr = pr_next

if __name__ == "__main__":
    main()
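For reference, the pruning schedule in main() and the global-threshold mask in pruneWeights._IMP can be expressed with NumPy alone, which also drops the tensorflow_probability dependency. This is an illustrative sketch: cumulative_sparsity and global_magnitude_masks are hypothetical helpers, and the threshold may differ marginally from the tfp version because the two percentile functions interpolate differently by default.

```python
import numpy as np


def cumulative_sparsity(rate_per_round, round_index):
    # Same formula as total_percentile = 1 - (1 - percentile_per_it) ** (i + 1)
    return 1.0 - (1.0 - rate_per_round) ** (round_index + 1)


def global_magnitude_masks(kernels, sparsity):
    """0/1 masks keeping weights whose magnitude exceeds one threshold
    computed over all kernels together (a global, not per-layer, cut)."""
    all_magnitudes = np.concatenate([np.abs(k).ravel() for k in kernels])
    # np.percentile interpolates linearly by default; tfp.stats.percentile
    # defaults to nearest-rank, so the threshold can differ very slightly.
    thresh = np.percentile(all_magnitudes, sparsity * 100)
    return [(np.abs(k) > thresh).astype(np.float32) for k in kernels]


for i in (0, 1, 13):
    print(i, round(cumulative_sparsity(0.20, i), 4))
# 0 0.2
# 1 0.36
# 13 0.956
```

With percentile_per_it = 0.20 and it = 14, the final round therefore prunes about 95.6% of the kernel weights.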

(MatthewWiens101, Jun 20 '22)

@MatthewWiens101 TF v2.3 is not actively supported; we recommend upgrading to the latest TF version. I tried to replicate this issue on Colab. Could you please look at the gist here and confirm? Thank you!

(sushreebarsa, Jun 22 '22)

@sushreebarsa Sorry, there were some typos in the shared code. I have updated the gist here, and it runs fine in TF v2.8. It still reproduces the issue, with this warning:

5/733 [..............................] - ETA: 11s - loss: 1.4617 - accuracy: 0.7594 
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0024s vs `on_train_batch_end` time: 0.0122s). Check your callbacks.

(MatthewWiens101, Jun 22 '22)

@MatthewWiens101 Could you have a look at this gist and let us know whether it shows the current behavior of the reported issue? Thank you!

(sushreebarsa, Jun 24 '22)

@sushreebarsa Yes the behavior seen in that gist matches the issue I have reported.

(MatthewWiens101, Jun 24 '22)

@sachinprasadhs any update on this issue? Still looking for a faster workaround.

(MatthewWiens101, Sep 09 '22)

I have the same issue with tf.GradientTape(). I use it to monitor the gradients in on_epoch_end, and each iteration takes around 50 minutes, while training itself takes less than 5 minutes.
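The commenter's code is not shown, so this is only a guess at the cause: computing gradients eagerly, op by op, inside a callback can be much slower than running the same computation as a traced graph. A minimal sketch, assuming a small stand-in model and batch (model, loss_fn, x, y, and compute_gradients are all hypothetical), wraps the GradientTape computation in tf.function:

```python
import tensorflow as tf

# Hypothetical stand-in model and batch, not the commenter's setup.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
loss_fn = tf.keras.losses.MeanSquaredError()
x = tf.random.normal((32, 8))
y = tf.random.normal((32, 1))


@tf.function
def compute_gradients(x_batch, y_batch):
    # Traced once per input signature; later calls reuse the graph
    # instead of dispatching every op eagerly from Python.
    with tf.GradientTape() as tape:
        loss = loss_fn(y_batch, model(x_batch, training=False))
    return tape.gradient(loss, model.trainable_variables)


grads = compute_gradients(x, y)
print([tuple(g.shape) for g in grads])  # [(8, 16), (16,), (16, 1), (1,)]
```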

(azd-rzzd, Nov 15 '22)

Looking at the code here, it seems like you have a mask_dict which is static in the context of an individual model.fit() call. Is that right?

If that is the case, you would probably see much better performance by writing a custom layer, perhaps called MaskedDense, that implements the logic you want and receives the static mask as an argument. It is hard to say exactly what the right structure would be without digging further into the use case, but the overall goal should be to remove on_train_batch_end and write simple layers that do everything you need inside call.
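A minimal sketch of the MaskedDense idea, assuming the mask is a fixed 0/1 array with the same shape as the kernel. The class and its kernel_mask argument are hypothetical, not a Keras API. Subclassing Dense keeps the usual kernel and bias, and multiplying the kernel by the mask inside call() makes the masking part of the compiled train step.

```python
import numpy as np
import tensorflow as tf


class MaskedDense(tf.keras.layers.Dense):
    """Hypothetical Dense layer whose kernel is multiplied by a fixed
    0/1 mask in call(), so pruned connections never contribute."""

    def __init__(self, units, kernel_mask, **kwargs):
        super().__init__(units, **kwargs)
        # A constant, not a weight: it is not trained and stays fixed
        # for the whole model.fit() call.
        self.kernel_mask = tf.constant(kernel_mask, dtype=tf.float32)

    def call(self, inputs):
        outputs = tf.matmul(inputs, self.kernel * self.kernel_mask)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs


# Usage sketch: a 4 -> 3 layer with a random mask, called eagerly.
rng = np.random.default_rng(0)
mask = (rng.random((4, 3)) > 0.5).astype("float32")
layer = MaskedDense(3, mask)
x = rng.standard_normal((2, 4)).astype("float32")
out = layer(x).numpy()
```

In the script above, each Dense(...) in the pruned model would become something like MaskedDense(300, mask_dict['w_1'], activation='relu'), and the callback would only need to rewind the initial weights once. One consequence of this design: the stored kernel keeps its values at pruned positions, but they do not affect the output and receive zero gradient.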

In general, Keras achieves the best performance with your model when everything is compiled into a tf.function. This guide might be a useful reference. You don't need to do anything fancy to get this working with Keras: just build a model and compile it as normal, and you are running with tf.function under the hood.

However, overriding the weights of every layer eagerly between each train step of your model (the compiled, fast part) will be much slower than bringing the w_mask logic into the model's actual compiled train step.
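Another option, not mentioned in this thread and swapped in here as an alternative, is a kernel constraint. Keras applies a variable's constraint right after each optimizer update inside the train step, so pruned entries are zeroed in place (as pruneModelCallback does) without leaving the compiled step. FixedMask is a hypothetical name; tf.keras.constraints.Constraint and the kernel_constraint argument of Dense are the real Keras pieces it relies on.

```python
import numpy as np
import tensorflow as tf


class FixedMask(tf.keras.constraints.Constraint):
    """Hypothetical constraint that zeroes pruned kernel entries.
    Keras calls it on the kernel after every optimizer update."""

    def __init__(self, mask):
        self.mask = tf.constant(mask, dtype=tf.float32)

    def __call__(self, w):
        return w * self.mask


# Usage sketch: random data and a 0/1 mask standing in for mask_dict["w_1"].
rng = np.random.default_rng(0)
mask = (rng.random((4, 3)) > 0.5).astype("float32")
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, kernel_constraint=FixedMask(mask)),
])
model.compile(optimizer="sgd", loss="mse")
model.fit(rng.standard_normal((16, 4)), rng.standard_normal((16, 3)),
          epochs=1, batch_size=8, verbose=0)
# Pruned entries are exactly zero after the optimizer steps.
print(np.count_nonzero(model.layers[0].kernel.numpy() * (1 - mask)))  # 0
```

Like the original callback, the constraint only acts after an update, so freshly loaded initial weights stay unmasked until the first training step.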

Hope that helps!

(mattdangerw, Nov 17 '22)