
How to get quantized weights from QAT model?

Open Hackerman28 opened this issue 4 years ago • 11 comments

Hi all. I've recently trained a Keras implementation of ssd-keras. I managed to run QAT training on the model and got the desired accuracy. I wanted to get the quantized weights from the QAT model saved as an H5 model. There's no support or documentation regarding this on the TensorFlow website. How can I get the quantized weights from the saved model after QAT? I tried converting it to TFLite, but the conversion fails due to a custom layer in the model definition. So it would be helpful if I could get the quantized weights alone from the saved model. Here's the code snippet for my QAT training. I am using TF 2.3.

quantize_scope = tfmot.quantization.keras.quantize_scope

def apply_quantization_to_dense(layer):
  if 'priorbox' in layer.name:
    return layer
  if isinstance(layer, tf.keras.layers.Concatenate) or \
     isinstance(layer, tf.keras.layers.Reshape) or \
     isinstance(layer, tf.keras.layers.Lambda):
    return layer
  return tfmot.quantization.keras.quantize_annotate_layer(layer)

annotated_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)

with quantize_scope({'AnchorBoxes': AnchorBoxes}):
  quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

quant_aware_model.summary()
quant_aware_model.compile(optimizer=adam, loss=ssd_loss.compute_loss)
quant_aware_model.fit_generator(train_generator,
                                epochs=424,
                                steps_per_epoch=1000,
                                callbacks=callbacks,
                                validation_data=val_generator,
                                validation_steps=100,
                                initial_epoch=414)

Hackerman28 commented on Feb 03 '21

Assuming you have implemented the default 8-bit scheme:

for layer in keras_model.layers:
  if hasattr(layer, 'quantize_config'):
    for weight, quantizer, quantizer_vars in layer._weight_vars:
      quantized_and_dequantized = quantizer(weight, training=False, weights=quantizer_vars)
      min_var = quantizer_vars['min_var']
      max_var = quantizer_vars['max_var']
      quantized = dequantize(quantized_and_dequantized, min_var, max_var, quantizer)

Where dequantize is specific to the quantizer, based on the options there. For example, for num_bits = 8, per_axis = False, symmetric = True, narrow_range = False:

quantized_min = -(1 << (quantizer.num_bits - 1))
scale = min_var / quantized_min
quantized = tf.math.round(quantized_and_dequantized / scale).numpy().astype(np.int8)
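
Putting the two snippets together, a minimal end-to-end sketch might look like the following. It assumes the default 8-bit, per-tensor, symmetric weight scheme and relies on the private _weight_vars attribute used above, so treat it as illustrative rather than a stable API:

import numpy as np
import tensorflow as tf

def extract_int8_weights(quant_aware_model):
  # Returns {variable name: (int8 values, scale)} for every quantized weight.
  int8_weights = {}
  for layer in quant_aware_model.layers:
    if not hasattr(layer, 'quantize_config'):
      continue  # not a quantize wrapper; read layer.get_weights() directly
    for weight, quantizer, quantizer_vars in layer._weight_vars:
      # Fake-quantized (quantize-then-dequantize) float values.
      qdq = quantizer(weight, training=False, weights=quantizer_vars)
      min_var = quantizer_vars['min_var']
      # Symmetric scheme: the scale follows from the most negative representable integer.
      quantized_min = -(1 << (quantizer.num_bits - 1))
      scale = min_var.numpy() / quantized_min
      int8_weights[weight.name] = (
          tf.math.round(qdq / scale).numpy().astype(np.int8), scale)
  return int8_weights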

daverim commented on Mar 19 '21

@daverim does TensorFlow have any documentation on how one should specify the dequantize operator for other cases? For example, if I'm using per-tensor, asymmetric, narrow-range, 3-bit quantization, how should I define my scale and zero-point variables? I currently have a makeshift implementation where the equations for scale/zp don't have any real motivation, so it would be nice to see how TF does it.

Further, when obtaining a quantized model in this way by looping through all layers, what should one do with layers that don't have min/max properties, for example Batch-Norm layers?

LucasStromberg commented on Apr 07 '21

The documentation is basically the code -- in our case we use fake_quant_with_min_max_args:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/fake_quant_ops_functor.h

We currently default to an 8-bit scheme that matches TFLite: https://www.tensorflow.org/lite/performance/quantization_spec

I actually made a mistake here: should be

quantized = quantize(quantized_and_dequantized, min_var, max_var, quantizer)

The basic idea is the same though. If you want to check your code, dequantize(quantized) should equal quantized_and_dequantized.

For batchnorm, you should just return the non-quantized weights using layer.get_weights().
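
For the per-tensor asymmetric case asked about above, the usual affine mapping can be sketched as below. This is the standard scale/zero-point formula over a signed integer range, not necessarily the exact TF implementation (the fake_quant ops additionally nudge min/max so that 0.0 is exactly representable, which this sketch skips):

import numpy as np

def affine_quant_params(min_val, max_val, num_bits=8, narrow_range=False):
  # Signed integer range, e.g. [-4, 3] for 3 bits, or [-3, 3] with narrow_range.
  qmin = -(1 << (num_bits - 1)) + (1 if narrow_range else 0)
  qmax = (1 << (num_bits - 1)) - 1
  scale = (max_val - min_val) / (qmax - qmin)
  zero_point = int(np.clip(round(qmin - min_val / scale), qmin, qmax))
  return scale, zero_point, qmin, qmax

def affine_quantize(x, scale, zero_point, qmin, qmax):
  # real_value ~= (quantized_value - zero_point) * scale
  return np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int32)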

daverim commented on May 17 '21

Great stuff, thank you.

LucasStromberg commented on May 21 '21

@daverim Thanks for the great explanation above. I have a follow-up doubt on this: suppose I wanted to fold all the batch norms into the convolutions after QAT; this will alter the kernel weights. After doing so, if I use the stored min/max to quantize the weights, is this appropriate/similar to TFLite's implementation of batchnorm folding?

Abhishekvats1997 commented on Aug 17 '21

That is right: the stored min/max will be incorrect if collected before folding but used after folding the batch norms. It is probably simplest to just get the values after folding and recalculate the min and max (numpy.max(...)). However, if you really want to do this calculation beforehand, the code below is basically how the min and max should be adjusted when folding is handled.

https://github.com/tensorflow/model-optimization/blob/v0.6.0/tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm.py#L289
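
For reference, a minimal sketch of the standard conv + batch-norm folding itself (assuming a Conv2D followed by BatchNormalization, with the Keras kernel layout where the output channel is the last axis):

import numpy as np

def fold_batchnorm(kernel, bias, gamma, beta, moving_mean, moving_var, epsilon=1e-3):
  # Per-output-channel multiplier derived from the batch norm statistics.
  factor = gamma / np.sqrt(moving_var + epsilon)
  folded_kernel = kernel * factor                   # broadcasts over the last axis
  folded_bias = beta + (bias - moving_mean) * factor
  return folded_kernel, folded_bias

# After folding, recompute the weight range from the folded kernel instead of
# reusing the min/max collected during QAT, e.g.
#   new_min, new_max = folded_kernel.min(), folded_kernel.max()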

daverim commented on Aug 18 '21

@daverim Is this kind of folding emulation currently implemented during QAT? I was just wondering whether there will be a significant loss in the network's accuracy if folding is not emulated during QAT but done post-training.

Abhishekvats1997 commented on Aug 18 '21

No, there is no folding emulation during training any more -- that is deprecated code, but it is still useful for understanding the calculation. This is now handled during batchnorm folding in the TFLite converter, but as you can see in the code example, it is a straightforward float calculation. The folding itself should not lead to a loss of accuracy, as it is a linear operation.

daverim commented on Aug 18 '21

@daverim What a coincidence lol, the file behind the code reference you mentioned got removed by TF today.

Abhishekvats1997 commented on Aug 18 '21

Sorry, I edited the last comment to point to the current release tag. As I mentioned, that folding is no longer done in TF. However, the calculation is essentially the same in the TFLite converter.

daverim commented on Aug 19 '21

@daverim Hi, I want to know how to get the quantized biases.

biases_quantized = tf.math.round(layer.weights[1] / scale).numpy().astype(np.int32)

Thank you very much.

JonneryR commented on Apr 04 '23