model-optimization
QAT with trainable=False does not work as expected.
Describe the bug
System information
TensorFlow version (installed from source or binary): 2.6.0
TensorFlow Model Optimization version (installed from source or binary): 0.7.0
Python version: 3.7.10
Describe the expected behavior
After setting trainable=False on layers with a quantisation wrapper applied, the weights in those layers should not change during training.
Describe the current behavior
The loss decreases during training even when all layers are set to be non-trainable.
Code to reproduce the issue
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Random data, used as both input and target.
A = np.random.uniform(size=(10000, 10, 10))

print('Expected behaviour')
inp = tf.keras.Input(shape=(10, 10), batch_size=10)
out = tf.keras.layers.Dense(10)(inp)
model = tf.keras.Model(inp, out)
# Freeze every layer; the loss should stay constant during fit().
for layer in model.layers:
    layer.trainable = False
model.compile(loss='mse')
model.fit(A, A, batch_size=10, epochs=5)
print('{} trainable weights'.format(len(model.layers[1].trainable_weights)))

print('\nQuantised behaviour')
inp = tf.keras.Input(shape=(10, 10), batch_size=10)
out = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(inp)
quant_model = tfmot.quantization.keras.quantize_apply(tf.keras.Model(inp, out))
# Freeze every layer, including the quantisation wrappers.
for layer in quant_model.layers:
    layer.trainable = False
quant_model.compile(loss='mse')
quant_model.fit(A, A, batch_size=10, epochs=5)
print('{} trainable weights'.format(len(quant_model.layers[2].trainable_weights)))
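Since the loss is only an indirect signal, one way to check the expected behaviour directly is to snapshot the weights before and after training and compare them. The helper below is a minimal sketch and not part of the original report: `changed_weight_indices` is a hypothetical name, and it assumes the snapshots are lists of NumPy arrays such as those returned by Keras `model.get_weights()`. The arrays used here are stand-ins so that the block runs without TensorFlow.

```python
import numpy as np


def changed_weight_indices(before, after, atol=0.0):
    """Return the indices of arrays that differ between two weight snapshots."""
    if len(before) != len(after):
        raise ValueError("snapshots must contain the same number of arrays")
    return [
        i
        for i, (b, a) in enumerate(zip(before, after))
        # Compare shapes first so that broadcasting cannot hide a mismatch.
        if np.shape(b) != np.shape(a)
        or not np.allclose(b, a, rtol=0.0, atol=atol)
    ]


# Stand-in snapshots (assumption): in practice these would come from
# `quant_model.get_weights()` called before and after `quant_model.fit(...)`.
before = [np.ones((10, 10)), np.zeros(10)]
after = [np.ones((10, 10)), np.full(10, 0.5)]

print(changed_weight_indices(before, after))
```

With the models above, an empty result after `fit` would mean that no weight changed. The indices follow the order of the list passed in, so with `get_weights()` they correspond to positions in `quant_model.weights`.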
@Xhark Can you take a look at this?
I'm also encountering this problem. Part of the problem appears to be that QuantizeWrapper doesn't account for self.trainable == False in trainable_weights and non_trainable_weights (cf. tf.keras.layers.Layer, which does account for this).
However, I've patched this as per below, and it still doesn't appear to stop the weights in these layers from training. Furthermore, I'm not sure how to stop quantization parameters from training (they're always non-trainable weights, but obviously update based on the input data to the layer, and I'd like to stop them from changing as well).
from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper

QuantizeWrapper = quantize_wrapper.QuantizeWrapper


def trainable_weights(self):
    if self.trainable:
        return self.layer.trainable_weights + self._trainable_weights
    else:
        return []


def non_trainable_weights(self):
    if self.trainable:
        return self.layer.non_trainable_weights + self._non_trainable_weights
    else:
        # Return layer weights first, and previously trainable before non-trainable,
        # to maintain the order in `self.weights`.
        # TODO: This won't be the correct order if `self.layer` has weights that are
        # always not trainable. To get the correct order, we would have to use
        # `layer._trainable_weights` and `layer._non_trainable_weights`, or switch
        # over to using QuantizeWrapperV2
        return (
            self.layer.trainable_weights +
            self.layer.non_trainable_weights +
            self._trainable_weights +
            self._non_trainable_weights
        )


QuantizeWrapper.trainable_weights = property(trainable_weights)
QuantizeWrapper.non_trainable_weights = property(non_trainable_weights)
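The mechanism used by this patch (replacing a property on an existing class so that it consults `self.trainable`) can be seen in isolation with a toy example. This is only a sketch in plain Python: `ToyLayer`, `ToyWrapper` and `patched_trainable_weights` are hypothetical stand-ins and do not reproduce the real `QuantizeWrapper` or Keras classes.

```python
class ToyLayer:
    """Stand-in for a wrapped layer that owns a kernel and a bias."""

    def __init__(self):
        self.trainable_weights = ["kernel", "bias"]
        self.non_trainable_weights = []


class ToyWrapper:
    """Stand-in wrapper whose property ignores its own `trainable` flag."""

    def __init__(self, layer):
        self.layer = layer
        self.trainable = True
        self._trainable_weights = []
        self._non_trainable_weights = ["min", "max"]

    @property
    def trainable_weights(self):
        return self.layer.trainable_weights + self._trainable_weights


def patched_trainable_weights(self):
    # Report no trainable weights when the wrapper itself is frozen.
    if not self.trainable:
        return []
    return self.layer.trainable_weights + self._trainable_weights


wrapper = ToyWrapper(ToyLayer())
wrapper.trainable = False
unpatched = wrapper.trainable_weights  # the flag is ignored here

# Replace the property on the class, as the patch above does.
ToyWrapper.trainable_weights = property(patched_trainable_weights)
patched = wrapper.trainable_weights  # the flag is respected now

print(unpatched, patched)
```

Because the property lives on the class, instances created before the assignment see the new behaviour too. A subclass that defines its own property of the same name would not pick up the change and would need to be patched separately.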
EDIT: I got this patch working by patching QuantizeWrapperV2 as well; see below.
I've started work on a patch here: https://github.com/hunse/model-optimization/pull/1. It ensures that the layer's trainable parameters (e.g. kernel, bias) do not get trained when trainable=False is set on the layer. @metinsuloglu, with it your example above pretty much passes: trainable_weights is empty and the loss stays almost exactly constant.
What it doesn't do yet is stop the quantization parameters from changing; for that reason, the loss does fluctuate slightly in your example. To freeze quantization parameters, we would probably want a different interface to set them to be non-adjustable. They are already considered non-trainable weights and aren't updated as part of the backprop pass (which is how trainable weights are updated), so it doesn't make sense that setting trainable=False would freeze them.
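To illustrate the distinction: range parameters of this kind are refreshed from the data seen in the forward pass rather than by gradients, so an optimizer-level switch never touches them. The sketch below is a toy NumPy model of that idea with a separate, hypothetical `frozen` flag. `ToyRangeTracker`, `observe` and `decay` are illustrative assumptions and not the TFMOT API.

```python
import numpy as np


class ToyRangeTracker:
    """Tracks a moving-average min/max of the inputs it observes."""

    def __init__(self, decay=0.9):
        self.decay = decay
        self.min_val = 0.0
        self.max_val = 0.0
        self.frozen = False  # hypothetical "non-adjustable" switch

    def observe(self, batch):
        # The update is driven by the input data, not by backprop, so a
        # `trainable` flag would have no effect; only `frozen` stops it.
        if self.frozen:
            return
        batch = np.asarray(batch, dtype=float)
        d = self.decay
        self.min_val = d * self.min_val + (1 - d) * float(batch.min())
        self.max_val = d * self.max_val + (1 - d) * float(batch.max())


tracker = ToyRangeTracker(decay=0.5)
tracker.observe([-2.0, 4.0])
range_after_first = (tracker.min_val, tracker.max_val)

tracker.frozen = True
tracker.observe([-100.0, 100.0])  # ignored while frozen
range_after_frozen = (tracker.min_val, tracker.max_val)

print(range_after_first, range_after_frozen)
```

A dedicated flag like this keeps the meaning of `trainable` limited to gradient-updated weights, which matches the reasoning above; how such a switch would best be exposed on the real wrapper is left open here.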
OK, I've added a test to https://github.com/hunse/model-optimization/pull/1. Let me know if/when I can make a PR here (since the CONTRIBUTING.md document says not to make a PR until an issue is marked as "contributions welcome").
@xhae @Xhark @fredrec any updates on this?
I'm having a related problem: as soon as I quantise my model, the trainable settings are reset and I can't change them back.