Attention layer does not accept output of previous layers in functional API
As an exercise to get acquainted with Keras, I want to train a simple sentence-translation model with attention.
I am not calling any TensorFlow function directly, only Keras layers, but I get the following error:
A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces keras.layers and keras.operations). [...]
Here is the code for the model using Keras' functional API:
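import tensorflow as tf

# text_vec_layer_en, text_vec_layer_es and vocab_size are assumed to be defined beforehand (not shown in this snippet).
encoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)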
decoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)
embed_size = 128
encoder_inputs_ids = text_vec_layer_en(encoder_inputs)
decoder_inputs_ids = text_vec_layer_es(decoder_inputs)
encoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size, mask_zero=True)
decoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size, mask_zero=True)
encoder_embeddings = encoder_embedding_layer(encoder_inputs_ids)
decoder_embeddings = decoder_embedding_layer(decoder_inputs_ids)
encoder = tf.keras.layers.LSTM(512, return_sequences=True, return_state=True)
encoder_outputs, *encoder_state = encoder(encoder_embeddings)
decoder = tf.keras.layers.LSTM(512, return_sequences=True)
decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)
# Attention layer here!
# Problems getting it to work on Keras 3
attention_layer = tf.keras.layers.Attention()
attention_outputs = attention_layer([decoder_outputs, encoder_outputs])
output_layer = tf.keras.layers.Dense(vocab_size, activation="softmax")
Y_probas = output_layer(attention_outputs)
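For reference, the computation I expect `Attention()` to perform on `[decoder_outputs, encoder_outputs]` (query first, value second) can be sketched in plain Python for a single example. This is only an illustration of dot-product attention, not Keras' implementation: the function name `dot_product_attention` and its `value_mask` argument are hypothetical, only a value mask is handled, and the sketch assumes at least one value position is unmasked.

```python
import math


def dot_product_attention(query, value, value_mask=None):
    """Dot-product attention for one example (illustrative sketch).

    query: list of Tq vectors, each of length d (e.g. decoder outputs).
    value: list of Tv vectors, each of length d (e.g. encoder outputs);
           the values also serve as keys, as when only two inputs are given.
    value_mask: optional list of Tv booleans; False marks padding positions.
    Returns a list of Tq vectors of length d.
    """
    dim = len(value[0])
    outputs = []
    for q in query:
        # Score each value position by its dot product with the query vector.
        scores = [sum(qi * vi for qi, vi in zip(q, v)) for v in value]
        if value_mask is not None:
            # Masked positions get -inf so they receive zero attention weight.
            scores = [s if keep else -math.inf
                      for s, keep in zip(scores, value_mask)]
        # Numerically stable softmax over the value positions.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # The output is the attention-weighted average of the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, value))
                        for j in range(dim)])
    return outputs
```

The result has one vector per query position, so in the model above I would expect `attention_outputs` to have the same time dimension as `decoder_outputs`.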
Expected behavior: the Keras Attention layer accepts Keras tensor inputs, or at least raises a more helpful error message.
Python version: 3.11.0
TensorFlow version: 2.17.0
Keras version: 3.4.1 (bundled with that TensorFlow version)
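To confirm the versions on another machine, a small standard-library sketch like the following may help. The helper name `installed_versions` is hypothetical, and the distribution names `tensorflow` and `keras` are assumptions: some installs ship TensorFlow under a different distribution name, in which case the helper reports it as missing.

```python
import platform
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed under this exact name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", platform.python_version())
    for name, version in installed_versions(["tensorflow", "keras"]).items():
        print(name, version or "not installed")
```

Reading the metadata avoids importing TensorFlow itself, so the check still works when the import fails.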
Thanks!