Cannot save pruned model with MultiHeadAttention Layer
Describe the bug
Saving a model whose MultiHeadAttention layer is wrapped in PruneLowMagnitude fails with a duplicate dataset name error.
System information
TensorFlow version (installed from source or binary): 2.13.0rc1
TensorFlow Model Optimization version (installed from source or binary): 0.7.5
Python version: 3.10
Describe the expected behavior
The pruned model saves successfully.
Describe the current behavior
When saving the pruned model to an .h5 file, I get ValueError: Unable to create dataset (name already exists) on "mask:0".
Code to reproduce the issue
```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tempfile

if __name__ == '__main__':
    # model
    inputs = tf.keras.layers.Input(shape=(28, 28, 3))
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=3, activation='relu')(inputs)
    x = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=128)(query=x, value=x, key=x)
    outputs = tf.keras.layers.Flatten()(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='mse')

    # call model to initialize weights
    model(tf.ones((1, 28, 28, 3)))

    # prune model
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.5,
            final_sparsity=0.9,
            begin_step=0,
            end_step=1,
            frequency=1,
        ),
    }
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

    with tempfile.TemporaryDirectory() as temp_dir:
        model_for_pruning.save(temp_dir + '/model.h5')  # <-- fails
```
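One plausible reading of the error, not confirmed here: the pruning wrapper appears to add a separate "mask" (and "threshold") variable for each prunable kernel, and MultiHeadAttention has several kernels (query, key, value, output), so a single wrapped layer would end up holding several weights that share the name "mask:0". The HDF5 writer stores each weight as a dataset keyed by its name, which would produce exactly this collision. The helper below is a minimal sketch for checking that, assuming only that each weight exposes a .name string as TF variables do; duplicate_weight_names is a hypothetical helper, not part of TF or TFMOT, and the weight names in the demo are illustrative stand-ins.

```python
from collections import Counter
from types import SimpleNamespace


def duplicate_weight_names(weights):
    """Return the sorted weight names that occur more than once.

    Hypothetical helper: assumes each item has a `.name` string. Any name
    returned here would collide if written as an HDF5 dataset name twice.
    """
    counts = Counter(w.name for w in weights)
    return sorted(name for name, n in counts.items() if n > 1)


# Stand-in weights with illustrative names, mimicking a wrapper that adds
# one "mask"/"threshold" pair per pruned kernel.
fake_weights = [SimpleNamespace(name=n) for n in (
    "query/kernel:0", "mask:0", "threshold:0",
    "key/kernel:0", "mask:0", "threshold:0",
)]
print(duplicate_weight_names(fake_weights))  # ['mask:0', 'threshold:0']
```

On the actual repro, one could pass layer.weights for each layer in model_for_pruning.layers; a non-empty result for the wrapped MultiHeadAttention layer would support this explanation.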
Potentially related to #661 and #944.
Tagging @Xhark, as you have worked on similar issues in the past.
@Xhark Could you follow-up on this one?