model-optimization
Pruning: improve custom training loop API
The recommended path for pruning with a custom training loop is not as simple as it could be.
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras

pruned_model = setup_pruned_model()
loss = tf.keras.losses.categorical_crossentropy
optimizer = keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1  # placeholder value (assumed here; not defined in the original snippet)

# Non-boilerplate: everything from here to on_train_begin() is pruning-specific setup.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)  # optional TensorBoard logging.
log_callback.set_model(pruned_model)

step_callback.on_train_begin()
for _ in range(3):
  # only one batch given batch_size = 20 and input shape.
  step_callback.on_train_batch_begin(batch=unused_arg)
  inp = np.reshape(x_train, [self._BATCH_SIZE, 10])  # original shape: from [10].
  with tf.GradientTape() as tape:
    logits = pruned_model(inp, training=True)  # training=True is required for pruning to take effect
    loss_value = loss(y_train, logits)
  grads = tape.gradient(loss_value, pruned_model.trainable_variables)
  optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
  step_callback.on_epoch_end(batch=unused_arg)
  log_callback.on_epoch_end(batch=unused_arg)
...
The set_model calls and the pruned_model.optimizer assignment are unusual and easy to miss.
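One way the wiring could be bundled so it is harder to forget, as a rough sketch only: PruningLoopHooks below is a hypothetical helper, not part of tfmot, and RecordingCallback is a stand-in so the sketch runs without TensorFlow. It assumes the callbacks expose exactly the methods called in the snippet above (set_model, on_train_begin, on_train_batch_begin, on_epoch_end).

```python
from types import SimpleNamespace


class PruningLoopHooks:
  """Hypothetical helper, not part of tfmot: bundles the callback wiring."""

  def __init__(self, model, optimizer, step_callback, log_callback=None,
               unused_arg=-1):
    # The same steps the snippet above performs by hand.
    model.optimizer = optimizer
    step_callback.set_model(model)
    if log_callback is not None:
      log_callback.set_model(model)
    self.step_callback = step_callback
    self.log_callback = log_callback
    self.unused_arg = unused_arg

  def train_begin(self):
    self.step_callback.on_train_begin()

  def batch_begin(self):
    self.step_callback.on_train_batch_begin(batch=self.unused_arg)

  def epoch_end(self):
    self.step_callback.on_epoch_end(batch=self.unused_arg)
    if self.log_callback is not None:
      self.log_callback.on_epoch_end(batch=self.unused_arg)


class RecordingCallback:
  """Stand-in for a tfmot callback; it only records which hooks were called."""

  def __init__(self):
    self.model = None
    self.calls = []

  def set_model(self, model):
    self.model = model

  def on_train_begin(self):
    self.calls.append("train_begin")

  def on_train_batch_begin(self, batch):
    self.calls.append("batch_begin")

  def on_epoch_end(self, batch):
    self.calls.append("epoch_end")


model = SimpleNamespace(optimizer=None)  # placeholder for the pruned model
step_cb, log_cb = RecordingCallback(), RecordingCallback()
hooks = PruningLoopHooks(model, "adam-placeholder", step_cb, log_cb)

hooks.train_begin()
for _ in range(3):  # three epochs with one batch each, as in the snippet
  hooks.batch_begin()
  # ... forward pass under GradientTape and apply_gradients would go here ...
  hooks.epoch_end()
```

With the real callbacks, step_cb and log_cb would be the UpdatePruningStep and PruningSummaries instances; whether they accept being driven this way outside the snippet's exact call pattern is untested here.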
I was not sure whether this should be a separate issue, but I found that pruning does not work unless we explicitly pass the training=True argument in the model call. (This seemed counterintuitive to me, since I had assumed that training defaults to true in a Keras Model.)
In any case, it would likely be helpful to point out the training argument explicitly.
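A toy illustration of one possible reading of this behaviour, in plain Python rather than real Keras or tfmot code. It assumes that a directly called model treats an omitted training flag as inference mode, and that the pruning wrapper only refreshes its masks in training mode. ToyPrunableLayer and mask_updates are hypothetical names.

```python
class ToyPrunableLayer:
  """Toy stand-in, not the tfmot wrapper: counts how often masks would update."""

  def __init__(self):
    self.mask_updates = 0

  def __call__(self, x, training=None):
    # Assumption for this sketch: an omitted flag means inference mode.
    if training is None:
      training = False
    if training:
      self.mask_updates += 1  # pruning logic only runs in training mode
    return x


layer = ToyPrunableLayer()

layer([1.0, 2.0])  # direct call with the flag omitted
updates_without_flag = layer.mask_updates

layer([1.0, 2.0], training=True)  # explicit flag, as in the custom loop
updates_with_flag = layer.mask_updates
```

Under these assumptions, a custom loop that omits training=True never triggers a mask update. By contrast, fit is expected to pass the flag itself.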
Are there any updates on fully supporting pruning with both a custom training loop and the fit function? @alanchiao, does the example code you provided work for pruning a model with a custom training loop?
Hi @aqibsaeed,
Pruning via the fit function is currently available: https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras
I think Alan's code above works for any custom training loop.
Yes, it works, thanks.