model-optimization

Pruning: improve custom training loop API

Status: Open. alanchiao opened this issue 4 years ago • 5 comments

The recommended path for pruning with a custom training loop is not as simple as it could be.

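```python
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# setup_pruned_model() is a placeholder for a model wrapped with
# tfmot.sparsity.keras.prune_low_magnitude; x_train and y_train are assumed
# to hold 20 examples with input shape [10].
pruned_model = setup_pruned_model()

loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()

log_dir = tempfile.mkdtemp()
unused_arg = -1  # The index value is not used by these pruning callbacks.
BATCH_SIZE = 20

# Pruning-specific setup that a plain custom training loop would not need.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)  # Optional TensorBoard logging.
log_callback.set_model(pruned_model)

step_callback.on_train_begin()
for _ in range(3):
    log_callback.on_epoch_begin(epoch=unused_arg)  # Logs pruning summaries.
    # Only one batch per epoch, given BATCH_SIZE = 20 and the input shape.
    step_callback.on_train_batch_begin(batch=unused_arg)
    inp = np.reshape(x_train, [BATCH_SIZE, 10])  # Original shape: from [10].
    with tf.GradientTape() as tape:
        # training=True is needed; otherwise the pruning masks are not updated.
        logits = pruned_model(inp, training=True)
        loss_value = loss(y_train, logits)
    grads = tape.gradient(loss_value, pruned_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))

    step_callback.on_epoch_end(batch=unused_arg)
...
```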

The `set_model` calls and the `pruned_model.optimizer` assignment are unusual and easy to miss.
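A small wrapper could hide these easy-to-miss steps. The sketch below is hypothetical: `PruningHooks` and its method names are not part of tfmot. It performs the optimizer assignment and the `set_model` calls once, then forwards each loop hook to the right callback. It drives the logging callback from its epoch-begin hook, as the custom-loop example in the TFMOT pruning guide does, and passes indices positionally because `UpdatePruningStep.on_epoch_end` names its index parameter `batch` rather than `epoch`.

```python
from typing import Any, Optional


class PruningHooks:
    """Hypothetical helper (not part of tfmot) bundling the pruning setup."""

    def __init__(self, model: Any, optimizer: Any, step_callback: Any,
                 log_callback: Optional[Any] = None) -> None:
        # The two easy-to-miss steps: attach the optimizer to the model and
        # hand the model to each callback.
        model.optimizer = optimizer
        self.step_callback = step_callback
        self.log_callback = log_callback
        step_callback.set_model(model)
        if log_callback is not None:
            log_callback.set_model(model)

    def train_begin(self) -> None:
        self.step_callback.on_train_begin()

    def epoch_begin(self, epoch: int) -> None:
        # The logging callback is driven from its epoch-begin hook.
        if self.log_callback is not None:
            self.log_callback.on_epoch_begin(epoch)

    def batch_begin(self, batch: int) -> None:
        # Lets the step callback advance the pruning step before each batch.
        self.step_callback.on_train_batch_begin(batch)

    def epoch_end(self, epoch: int) -> None:
        # Positional argument: UpdatePruningStep names this parameter `batch`.
        self.step_callback.on_epoch_end(epoch)
```

In the loop from the issue, `hooks = PruningHooks(pruned_model, optimizer, tfmot.sparsity.keras.UpdatePruningStep(), tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir))` would replace the setup lines, and `hooks.train_begin()`, `hooks.epoch_begin(i)`, `hooks.batch_begin(i)` and `hooks.epoch_end(i)` would replace the individual callback calls.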

alanchiao, Feb 27 '20 02:02

I was not sure whether this should be a separate issue, but I found that pruning does not work unless we explicitly pass the `training=True` argument in the model call. (This seemed counterintuitive to me, since I had assumed that `training` defaults to true in a Keras `Model`.) In any case, it would likely help to call out the `training` argument explicitly in the documentation.
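A likely explanation, offered with some hedging: `fit()` calls the model with `training=True` internally, but a direct call such as `pruned_model(inp)` does not, and the pruning wrapper appears to update its masks only in training mode. The toy class below is a pure-Python analogue, not tfmot code (the name `ToyPrunedLayer` and its simple magnitude rule are hypothetical), showing how gating the mask update on `training` leaves the weights unpruned when the flag is omitted.

```python
class ToyPrunedLayer:
    """Toy stand-in for a pruning wrapper (not tfmot code).

    The mask is recomputed only when `training` is truthy, an illustrative
    assumption about how mask updates can be gated.
    """

    def __init__(self, weights, sparsity=0.5):
        self.weights = list(weights)
        self.sparsity = sparsity
        self.mask = [1.0] * len(self.weights)

    def _update_mask(self):
        # Zero out the smallest-magnitude `sparsity` fraction of the weights.
        n_prune = int(len(self.weights) * self.sparsity)
        order = sorted(range(len(self.weights)), key=lambda i: abs(self.weights[i]))
        self.mask = [1.0] * len(self.weights)
        for i in order[:n_prune]:
            self.mask[i] = 0.0

    def __call__(self, x, training=None):
        # As with a Keras model called directly rather than through fit(),
        # inference behaviour is the default unless training=True is passed.
        if training:
            self._update_mask()
        masked = [w * m for w, m in zip(self.weights, self.mask)]
        return sum(w * xi for w, xi in zip(masked, x))
```

With weights `[0.1, -2.0, 0.3, 4.0]`, calling `layer(x)` leaves `layer.mask` at all ones, while `layer(x, training=True)` zeroes the two smallest-magnitude weights.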

Naman-ntc, Jul 27 '20 09:07

Are there any updates on fully supporting pruning with a custom training loop and via the `fit` function? @alanchiao, does the example code you provided work for pruning a model with a custom training loop?

aqibsaeed, Nov 20 '21 14:11

Hi @aqibsaeed,

Pruning via the `fit` function is already supported; see https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras

I believe Alan's code above works for any custom training loop.

rino20, Nov 26 '21 02:11

Yes, it works. Thanks!

aqibsaeed, Nov 26 '21 07:11