
Support for Recurrent layers for Quantization Aware Training.

Open • parth-desai opened this issue 1 year ago • 1 comment

System information

  • TensorFlow version (you are using): 2.15
  • Are you willing to contribute it (Yes/No): Yes

Motivation

I am trying to train an RNN model with quantization-aware training for embedded devices.

Describe the feature

I am looking for a way to train with the default 8-bit weight and activation quantization using the quantize_apply API, without passing in a custom config.

Describe how the feature helps achieve the use case

Describe how existing APIs don't satisfy your use case (optional if obvious)

I tried to use the quantize_apply API but received this error:

RuntimeError: Layer gru:<class 'keras.src.layers.rnn.gru.GRU'> is not supported. You can quantize this layer by passing a `tfmot.quantization.keras.QuantizeConfig` instance to the `quantize_annotate_layer` API.

After using quantize_annotate_layer, I was able to train the model, but it fails to save with the following error:

  keras.models.save_model(model, filepath=model_filename, save_format="h5")
Traceback (most recent call last):
  File "/workspaces/project-embedded/syntiant-ndp-model-converter/examples/train_audio_model.py", line 169, in <module>
    keras.models.save_model(model, filepath=model_filename, save_format="h5")
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py", line 167, in save_model
    return legacy_sm_saving_lib.save_model(
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/h5py/_hl/group.py", line 183, in create_dataset
    dsid = dataset.make_new_dset(group, shape, dtype, data, name, **kwds)
  File "/home/vscode/tf_venv/lib/python3.10/site-packages/h5py/_hl/dataset.py", line 163, in make_new_dset
    dset_id = h5d.create(parent.id, name, tid, sid, dcpl=dcpl, dapl=dapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5d.pyx", line 137, in h5py.h5d.create
ValueError: Unable to synchronously create dataset (name already exists)
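h5py raises this error when a dataset is created under a name that already exists in the same HDF5 group. One way to narrow it down is to look for weights that share a name within a single layer before calling save_model. The helper below is hypothetical (it is not part of Keras or TFMOT). It assumes the legacy HDF5 writer stores each layer's weights as datasets keyed by the weight name, and it only inspects strings, so it can be fed `[(layer.name, w.name) for layer in model.layers for w in layer.weights]`.

```python
from collections import Counter
from typing import Iterable


def find_duplicate_weight_names(
    pairs: Iterable[tuple[str, str]],
) -> dict[str, list[str]]:
    """Hypothetical diagnostic: given (layer_name, weight_name) pairs, return
    {layer_name: [weight names that occur more than once in that layer]}.

    Assumption: the HDF5 writer creates one dataset per weight name inside a
    per-layer group, so any repeated name here would collide on save.
    """
    per_layer: dict[str, Counter] = {}
    for layer_name, weight_name in pairs:
        # Count how often each weight name appears within its own layer.
        per_layer.setdefault(layer_name, Counter())[weight_name] += 1
    return {
        layer: sorted(name for name, count in counts.items() if count > 1)
        for layer, counts in per_layer.items()
        if any(count > 1 for count in counts.values())
    }
```

An empty result suggests the collision comes from somewhere other than per-layer weight names; a non-empty one points at the wrapped layer whose quantizer variables are worth inspecting.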

I used the following QuantizeConfig:

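import tensorflow_model_optimization as tfmot

LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer


class GruQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):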
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
        return [
            (
                layer.cell.kernel,
                LastValueQuantizer(
                    num_bits=8, symmetric=True, narrow_range=False, per_axis=False
                ),
            ),
            (
                layer.cell.recurrent_kernel,
                LastValueQuantizer(
                    num_bits=8, symmetric=True, narrow_range=False, per_axis=False
                ),
            ),
        ]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
        return [
            (
                layer.cell.activation,
                MovingAverageQuantizer(
                    num_bits=8, symmetric=False, narrow_range=False, per_axis=False
                ),
            ),
            (
                layer.cell.recurrent_activation,
                MovingAverageQuantizer(
                    num_bits=8, symmetric=False, narrow_range=False, per_axis=False
                ),
            ),
        ]

    def set_quantize_weights(self, layer, quantize_weights):
        # Add one line for each item returned in `get_weights_and_quantizers`,
        # in the same order.
        layer.cell.kernel = quantize_weights[0]
        layer.cell.recurrent_kernel = quantize_weights[1]

    def set_quantize_activations(self, layer, quantize_activations):
        # Add one line for each item returned in `get_activations_and_quantizers`,
        # in the same order.
        layer.cell.activation = quantize_activations[0]
        layer.cell.recurrent_activation = quantize_activations[1]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
        return []

    def get_config(self):
        return {}
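For readers less familiar with what the two quantizer settings in this config do (symmetric ranges for the kernels, asymmetric ranges for the activations), here is a simplified pure-Python sketch of uniform 8-bit fake quantization, i.e. quantize to an integer grid and immediately dequantize. It assumes the common scheme in which the range is made to contain 0.0 and an integer zero point is chosen so that 0.0 is represented exactly. `fake_quant`, `symmetric_range` and `asymmetric_range` are illustrative names, not TFMOT functions, and the exact range handling inside TFMOT's LastValueQuantizer and MovingAverageQuantizer (for example, how MovingAverageQuantizer smooths the observed range across batches) may differ and is not modelled here.

```python
from typing import Sequence


def symmetric_range(values: Sequence[float]) -> tuple[float, float]:
    """Range centred on zero, in the spirit of symmetric=True (kernels above)."""
    r = max(abs(v) for v in values)
    return -r, r


def asymmetric_range(values: Sequence[float]) -> tuple[float, float]:
    """Observed min/max widened to include 0.0, in the spirit of symmetric=False."""
    return min(min(values), 0.0), max(max(values), 0.0)


def fake_quant(x: float, min_val: float, max_val: float,
               num_bits: int = 8, narrow_range: bool = False) -> float:
    """Illustrative uniform fake quantization of x over [min_val, max_val]."""
    if max_val <= min_val:
        raise ValueError("max_val must be greater than min_val")
    qmin = 1 if narrow_range else 0
    qmax = 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    # Integer zero point, clamped to the grid, so 0.0 maps to an exact level.
    zero_point = min(max(round(qmin - min_val / scale), qmin), qmax)
    # Quantize (round and clamp to the integer grid), then dequantize.
    q = min(max(round(x / scale) + zero_point, qmin), qmax)
    return (q - zero_point) * scale


weights = [-0.5, 0.1, 0.37]
lo, hi = symmetric_range(weights)
print([fake_quant(w, lo, hi) for w in weights])
```

Values inside the range come back within about one quantization step of the input, and values outside it are clamped to the nearest representable endpoint.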

I looked at the source code, and it seems that RNN support is disabled here for some reason.

Could this be re-enabled?

parth-desai • Feb 06 '24 01:02

Thanks for filing this issue, Parth.

As you said, RNN support appears to have been disabled because it was unsupported and had not yet been verified on TFLite.

We'll be keeping track of this feature request, but please note that LSTM / RNN / GRU variant support is not prioritized at the moment because it is less relevant to today's ML landscape than transformers.

Thanks, Jen

jenriver • Apr 16 '24 03:04