
Attention layer does not accept output of previous layers in functional API

Open jorgenorena opened this issue 1 year ago • 2 comments

As an exercise to get acquainted with Keras, I want to train a simple model with attention to translate sentences.

I am not calling any TensorFlow function directly, only Keras layers, but I get the following error:

A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces keras.layers and keras.operations). [...]

Here is the code for the model using Keras' functional API:

encoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)
decoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)

embed_size = 128
encoder_inputs_ids = text_vec_layer_en(encoder_inputs)
decoder_inputs_ids = text_vec_layer_es(decoder_inputs)
encoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size, mask_zero=True)
decoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size, mask_zero=True)
encoder_embeddings = encoder_embedding_layer(encoder_inputs_ids)
decoder_embeddings = decoder_embedding_layer(decoder_inputs_ids)

encoder = tf.keras.layers.LSTM(512, return_sequences=True, return_state=True)
encoder_outputs, *encoder_state = encoder(encoder_embeddings)

decoder = tf.keras.layers.LSTM(512, return_sequences=True)
decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)

# Attention layer here!
# Problems getting it to work on Keras 3
attention_layer = tf.keras.layers.Attention()
attention_outputs = attention_layer([decoder_outputs, encoder_outputs])

output_layer = tf.keras.layers.Dense(vocab_size, activation="softmax")
Y_probas = output_layer(attention_outputs)
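
For reference, the call attention_layer([decoder_outputs, encoder_outputs]) passes the decoder outputs as the query and the encoder outputs as the value, which the layer also uses as the key when no separate key is given. Below is a minimal plain-Python sketch of that dot-product attention computation for a single unbatched sequence. It is not Keras' implementation and does not reproduce the error; the helper names (softmax, dot_product_attention) and the toy inputs are hypothetical, and masking, scaling, and dropout are ignored.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(query, value, key=None):
    """Unbatched dot-product attention (illustrative helper, not a Keras API).

    query: list of Tq vectors, value: list of Tv vectors, all of equal length.
    key defaults to value, mirroring a call with only [query, value].
    Returns (outputs, weights): Tq output vectors and Tq rows of Tv weights.
    """
    if key is None:
        key = value
    dim = len(value[0])
    outputs, all_weights = [], []
    for q in query:
        # Score each key position by its dot product with the query vector.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in key]
        # Normalize the scores into weights that sum to 1.
        weights = softmax(scores)
        # The output is the weighted sum of the value vectors.
        out = [sum(w * v[d] for w, v in zip(weights, value)) for d in range(dim)]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


# Toy stand-ins for decoder_outputs (3 steps) and encoder_outputs (2 steps).
decoder_steps = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
encoder_steps = [[1.0, 0.0], [0.0, 1.0]]
outputs, weights = dot_product_attention(decoder_steps, encoder_steps)
```

The result has one output vector per decoder step, each a weighted average of the encoder steps, which is why the Dense output layer can be applied to it directly.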

Expected behavior: The Keras Attention layer accepts Keras tensor inputs, or at least a more helpful error message is given.

Python version: 3.11.0. TensorFlow version: 2.17.0. Keras version: 3.4.1 (bundled with that TensorFlow version).

jorgenorena avatar Oct 02 '24 13:10 jorgenorena

Thanks!

Hardvan avatar Oct 10 '24 00:10 Hardvan