
Failed to model.save when using tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper

yang-stressfree opened this issue 3 years ago • 3 comments

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): tested on macOS 12.2 and Ubuntu 18.04.6 LTS
  • TensorFlow version and how it was installed (source or binary): 2.7.0, binary (conda pip)
  • TensorFlow-Addons version and how it was installed (source or binary): 0.15.0, binary (conda pip)
  • Python version: 3.8
  • Is GPU used? (yes/no): tested both with and without a GPU

Describe the bug

If my decoder model includes tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper layers, the model fails to save.

If I remove the tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper layers from the model, it saves successfully.

Code to reproduce the issue

I created a Colab notebook to reproduce this issue:

https://colab.research.google.com/gist/yang-stressfree/bc78b67ca6f051fe60a7e863b99cc1b3#scrollTo=0jcDbzxVD4-h

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras as keras


class TinyDemoModel(keras.Model):
    def __init__(self, rnn_units, *args, **kwargs):
        super(TinyDemoModel, self).__init__(*args, **kwargs)
        self.vocab_size = 10
        self.rnn_units = rnn_units
        self.dec_embedding = keras.layers.Embedding(input_dim=self.vocab_size, output_dim=self.vocab_size)
        self.dec_lstm_cell = keras.layers.LSTMCell(units=self.rnn_units)
        self.dec_attention = tfa.seq2seq.BahdanauAttention(units=self.rnn_units)
        self.dec_rnn_cell = tfa.seq2seq.AttentionWrapper(cell=self.dec_lstm_cell,
                                                         attention_mechanism=self.dec_attention,
                                                         attention_layer_size=self.rnn_units)
        self.dec_fc = keras.layers.Dense(self.vocab_size)
        self.dec_train = tfa.seq2seq.BasicDecoder(self.dec_rnn_cell, tfa.seq2seq.sampler.TrainingSampler(),
                                                  output_layer=self.dec_fc)

    def dec_build_initial_state(self, batch_size, enc_h, enc_c):
        initial_state = self.dec_rnn_cell.get_initial_state(batch_size=batch_size, dtype=enc_h.dtype)
        initial_state = initial_state.clone(cell_state=[enc_h, enc_c])
        return initial_state

    def get_config(self):
        raise NotImplementedError

    def call(self, inputs, training=None, mask=None):
        batch_seq_encoded, batch_seq_labeled = inputs
        shape_batch_seq_labeled = tf.shape(batch_seq_labeled)
        batch_size = shape_batch_seq_labeled[0]
        self.dec_attention.setup_memory(batch_seq_encoded)
        initial_state = self.dec_build_initial_state(batch_size, tf.zeros([batch_size, self.rnn_units]),
                                                     tf.zeros([batch_size, self.rnn_units]))
        batch_seq_labeled_embedded = self.dec_embedding(batch_seq_labeled)
        output, _, _ = self.dec_train(batch_seq_labeled_embedded, initial_state=initial_state)
        # pad and return
        batch_seq_predicted_odds = output.rnn_output
        pad_size = shape_batch_seq_labeled[1] - tf.shape(batch_seq_predicted_odds)[1]
        batch_seq_predicted_odds = tf.pad(batch_seq_predicted_odds, [[0, 0], [0, pad_size], [0, 0]])
        return batch_seq_predicted_odds


def train_and_save():
    rnn_units = 8
    batch_size = 2
    batch_seq_encoded = tf.ones([batch_size, 100, rnn_units], dtype=tf.float32)
    batch_seq_labeled = tf.ones([batch_size, 100], dtype=tf.int32)
    model = TinyDemoModel(rnn_units)
    loss_obj = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_obj)
    model.fit(x=(batch_seq_encoded, batch_seq_labeled), y=batch_seq_labeled)
    model.save(filepath="/tmp/saved_model_tiny_demo_model")


if __name__ == "__main__":
    train_and_save()

Running this file produces:

Traceback (most recent call last):
  File "tiny_model.py", line 61, in <module>
    train_and_save()
  File "tiny_model.py", line 57, in train_and_save
    model.save(filepath="/tmp/saved_model_tiny_demo_model")
  File "lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "lib/python3.9/site-packages/tensorflow/python/saved_model/save.py", line 402, in map_resources
    raise ValueError(
ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with `tf.get_static_value`.

Other info / logs

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

yang-stressfree avatar Feb 21 '22 06:02 yang-stressfree

Similar issue https://github.com/tensorflow/addons/issues/2672

yang-stressfree avatar Feb 21 '22 06:02 yang-stressfree

Tested with:

tensorflow==2.8.0
tensorflow-addons==0.16.1

and got the same error 😢

yang-stressfree avatar Feb 21 '22 07:02 yang-stressfree

I figured out why this error occurs when trying to save the model with tf.saved_model.save or model.save:

    def call(self, inputs, training=None, mask=None):
        batch_seq_encoded, batch_seq_labeled = inputs
        # ...
        self.dec_attention.setup_memory(batch_seq_encoded)

Actually, the error message explains itself:

ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with tf.get_static_value.

tf.saved_model.save tries to convert batch_seq_encoded, a dynamic tensor that setup_memory stores on the attention mechanism, into a static constant, which is impossible because its value is only known at runtime.
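For context, here is my own minimal illustration (not part of the repro above) of what the error message means: tf.get_static_value can only fold tensors whose values are known at trace time, and a tensor derived from the runtime batch is not one of them.

import tensorflow as tf

# A value known at trace time can be folded to a constant:
print(tf.get_static_value(tf.constant(3)))  # -> 3

# A symbolic tensor that depends on runtime data cannot:
@tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
def f(x):
    batch = tf.shape(x)[0]             # only known when the function actually runs
    print(tf.get_static_value(batch))  # -> None, so it cannot be saved as a constant
    return x

f.get_concrete_function()  # trace once to see the prints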

For anyone who wants to export a model with tf.saved_model.save: DO NOT implement the attention mechanism with tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper.
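If you do need to keep the tfa layers, one possible workaround (only a sketch, I have not tested it in every setup) is to checkpoint just the variables with model.save_weights instead of exporting a full SavedModel, and rebuild the model in code before loading. Reusing the names from the repro above:

# Sketch: save only the variables (TF checkpoint format), skipping the traced functions.
model.save_weights("/tmp/tiny_demo_model_ckpt")

# Later / elsewhere: rebuild the same architecture, run it once in eager mode to
# create the variables, then restore the checkpoint.
restored = TinyDemoModel(rnn_units=8)
_ = restored((batch_seq_encoded, batch_seq_labeled))
restored.load_weights("/tmp/tiny_demo_model_ckpt")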

For a beginner, the approach in https://www.tensorflow.org/text/tutorials/nmt_with_attention is a workable and portable alternative.
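The key point of that tutorial is that the attention is built from plain Keras layers and the memory is passed in on every call, so nothing is captured on the layer between calls. A rough, simplified sketch of the pattern (my own class name, masks omitted, not the tutorial code verbatim):

import tensorflow as tf

class SimpleBahdanauAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau-style) attention built from plain Keras layers."""

    def __init__(self, units):
        super().__init__()
        self.W_query = tf.keras.layers.Dense(units, use_bias=False)
        self.W_key = tf.keras.layers.Dense(units, use_bias=False)
        self.attention = tf.keras.layers.AdditiveAttention()

    def call(self, query, value):
        # query: [batch, t_q, d_q], value (the encoder memory): [batch, t_v, d_v]
        projected_query = self.W_query(query)
        projected_key = self.W_key(value)
        # AdditiveAttention takes [query, value, key]; the memory is an argument of
        # call, not state stored on the layer, so saving the model works.
        context = self.attention([projected_query, value, projected_key])
        return context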

yang-stressfree avatar Feb 28 '22 13:02 yang-stressfree