model.save fails when using tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): tested on macOS 12.2 and Ubuntu 18.04.6 LTS
- TensorFlow version and how it was installed (source or binary): 2.7.0, binary (pip inside conda)
- TensorFlow-Addons version and how it was installed (source or binary): 0.15.0, binary (pip inside conda)
- Python version: 3.8
- Is GPU used? (yes/no): tested both with and without a GPU
Describe the bug
If my decoder model includes tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper layers, the model fails to save;
if I remove those two layers from the model, the model saves successfully.
Code to reproduce the issue
I created a Colab notebook to reproduce this issue:
https://colab.research.google.com/gist/yang-stressfree/bc78b67ca6f051fe60a7e863b99cc1b3#scrollTo=0jcDbzxVD4-h
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras as keras


class TinyDemoModel(keras.Model):
    def __init__(self, rnn_units, *args, **kwargs):
        super(TinyDemoModel, self).__init__(*args, **kwargs)
        self.vocab_size = 10
        self.rnn_units = rnn_units
        self.dec_embedding = keras.layers.Embedding(input_dim=self.vocab_size, output_dim=self.vocab_size)
        self.dec_lstm_cell = keras.layers.LSTMCell(units=self.rnn_units)
        self.dec_attention = tfa.seq2seq.BahdanauAttention(units=self.rnn_units)
        self.dec_rnn_cell = tfa.seq2seq.AttentionWrapper(cell=self.dec_lstm_cell,
                                                         attention_mechanism=self.dec_attention,
                                                         attention_layer_size=self.rnn_units)
        self.dec_fc = keras.layers.Dense(self.vocab_size)
        self.dec_train = tfa.seq2seq.BasicDecoder(self.dec_rnn_cell, tfa.seq2seq.sampler.TrainingSampler(),
                                                  output_layer=self.dec_fc)

    def dec_build_initial_state(self, batch_size, enc_h, enc_c):
        initial_state = self.dec_rnn_cell.get_initial_state(batch_size=batch_size, dtype=enc_h.dtype)
        initial_state = initial_state.clone(cell_state=[enc_h, enc_c])
        return initial_state

    def get_config(self):
        raise NotImplementedError

    def call(self, inputs, training=None, mask=None):
        batch_seq_encoded, batch_seq_labeled = inputs
        shape_batch_seq_labeled = tf.shape(batch_seq_labeled)
        batch_size = shape_batch_seq_labeled[0]
        # the attention memory is set from a dynamic input tensor here;
        # see the analysis further below
        self.dec_attention.setup_memory(batch_seq_encoded)
        initial_state = self.dec_build_initial_state(batch_size, tf.zeros([batch_size, self.rnn_units]),
                                                     tf.zeros([batch_size, self.rnn_units]))
        batch_seq_labeled_embedded = self.dec_embedding(batch_seq_labeled)
        output, _, _ = self.dec_train(batch_seq_labeled_embedded, initial_state=initial_state)
        # pad and return
        batch_seq_predicted_odds = output.rnn_output
        pad_size = shape_batch_seq_labeled[1] - tf.shape(batch_seq_predicted_odds)[1]
        batch_seq_predicted_odds = tf.pad(batch_seq_predicted_odds, [[0, 0], [0, pad_size], [0, 0]])
        return batch_seq_predicted_odds


def train_and_save():
    rnn_units = 8
    batch_size = 2
    batch_seq_encoded = tf.ones([batch_size, 100, rnn_units], dtype=tf.float32)
    batch_seq_labeled = tf.ones([batch_size, 100], dtype=tf.int32)
    model = TinyDemoModel(rnn_units)
    loss_obj = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_obj)
    model.fit(x=(batch_seq_encoded, batch_seq_labeled), y=batch_seq_labeled)
    model.save(filepath="/tmp/saved_model_tiny_demo_model")


if __name__ == "__main__":
    train_and_save()
Running this file produces:
Traceback (most recent call last):
  File "tiny_model.py", line 61, in <module>
    train_and_save()
  File "tiny_model.py", line 57, in train_and_save
    model.save(filepath="/tmp/saved_model_tiny_demo_model")
  File "lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "lib/python3.9/site-packages/tensorflow/python/saved_model/save.py", line 402, in map_resources
    raise ValueError(
ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with `tf.get_static_value`.
Other info / logs
Similar issue https://github.com/tensorflow/addons/issues/2672
Tested with:
tensorflow==2.8.0
tensorflow-addons==0.16.1
and got the same error 😢
I figured out why this error occurs when trying to save the model with tf.saved_model.save or model.save:
def call(self, inputs, training=None, mask=None):
    batch_seq_encoded, batch_seq_labeled = inputs
    # ...
    self.dec_attention.setup_memory(batch_seq_encoded)
Actually, the error explains itself:
ValueError: Unable to save function b'__inference_tiny_demo_model_layer_call_fn_6516' because it captures graph tensor Tensor("BahdanauAttention/strided_slice:0", shape=(), dtype=int32) from a parent function which cannot be converted to a constant with tf.get_static_value.
tf.saved_model.save tries to convert batch_seq_encoded, the dynamic tensor that call() installs as the attention mechanism's memory, into a static constant, and that fails because its value only exists at run time.
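For context, tf.get_static_value can only fold tensors whose value is already known at trace time; a quick sketch (nothing tfa-specific, just plain TF):

import tensorflow as tf

# An eager/constant tensor has a static value:
print(tf.get_static_value(tf.constant(3)))  # -> 3

@tf.function
def f(x):
    # Inside a trace, x is symbolic; its value only exists at run time:
    print(tf.get_static_value(x))  # -> None (printed at trace time)
    return x + 1

f(tf.constant(3.0))

batch_seq_encoded inside call() is exactly such a symbolic tensor, so the saver cannot fold it into a constant.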
For anyone who wants to export a model with tf.saved_model.save: DO NOT implement the attention mechanism with tfa.seq2seq.BahdanauAttention and tfa.seq2seq.AttentionWrapper.
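If you have to keep the TFA layers anyway, one workaround I would expect to help (a sketch against the TinyDemoModel above, not something the maintainers confirmed) is to save only the weights, since save_weights checkpoints variables and never traces call():

# Workaround sketch: checkpoint the variables only; no function tracing happens.
model.save_weights("/tmp/tiny_demo_model_weights")

# To restore: rebuild the same architecture, run one batch so the variables
# get created, then load the checkpoint.
restored = TinyDemoModel(rnn_units=8)
restored((batch_seq_encoded, batch_seq_labeled))
restored.load_weights("/tmp/tiny_demo_model_weights")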
For a beginner, https://www.tensorflow.org/text/tutorials/nmt_with_attention is a workable and portable alternative.
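For reference, a rough sketch of the pattern that tutorial uses (the class and names here are mine, not the tutorial's): the encoder output flows through call() as a plain argument instead of being stashed with setup_memory(), so no graph tensor is captured across traced functions:

import tensorflow as tf
import tensorflow.keras as keras

class CrossAttention(keras.layers.Layer):
    """Attention that receives its memory on every call instead of storing it."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.mha = keras.layers.MultiHeadAttention(num_heads=1, key_dim=units)
        self.norm = keras.layers.LayerNormalization()

    def call(self, query, memory):
        # memory (the encoder output) is an ordinary input tensor here, so
        # tf.saved_model.save never tries to constant-fold it.
        attn_out = self.mha(query=query, value=memory)
        return self.norm(query + attn_out)

# usage: context = CrossAttention(units=8)(decoder_states, encoder_output)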