tf-keras
MultiHeadAttention masking vector gets saved as a list, which raises an exception when loading the model.
Please go to TF Forum for help and support:
https://discuss.tensorflow.org/tag/keras
If you open a GitHub issue, here is our policy:
It must be a bug, a feature request, or a significant problem with the documentation (for small docs fixes please send a PR instead). The form below must be filled out.
Here's why we have that policy:
Keras developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): YES
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04.5 LTS
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.9.1
- Python version: 3.9.7
- Bazel version (if compiling from source):
- GPU model and memory: None
- Exact command to reproduce: running the MRE provided below
You can collect some of this information using our environment capture script:
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
You can obtain the TensorFlow version with: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
Describe the problem.
When the MultiHeadAttention Keras layer is called with a masking vector, the mask is serialized as a plain Python list when the model is saved to a file. When the model is then loaded with keras.models.load_model, an AttributeError is raised because the masking vector is never cast back to a np.array (or tensor).
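For illustration, here is a minimal sketch of that round trip (not the actual Keras serialization code): the call arguments in the saved graph config are JSON-serialized, which turns the array/tensor mask into nested lists, and nothing converts them back on load:

import json
import numpy as np

# The look-ahead mask from the MRE below.
mask = np.tril(np.ones((2, 2)))

# Roughly what happens on save: the call kwargs are JSON-serialized,
# which turns the array into nested Python lists.
serialized = json.dumps(mask.tolist())

# Roughly what happens on load: the lists come back, but nothing
# casts them to an array/tensor, so `.shape` no longer exists.
restored = json.loads(serialized)
print(type(restored))              # <class 'list'>
print(hasattr(restored, "shape"))  # False -- hence the AttributeError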
Describe the current behavior.
The mask vector is a plain Python list after the model is loaded with keras.models.load_model.
Describe the expected behavior.
The mask vector should be cast back from a list to a numpy array when loaded with keras.models.load_model.
- Do you want to contribute a PR? (yes/no): yes
- If yes, please read this page for instructions
- Briefly describe your candidate solution (if contributing):
Standalone code to reproduce the issue.
from keras.layers import MultiHeadAttention, Input, Dense
from keras.models import Model
import tensorflow as tf
import numpy as np

# Build a lower-triangular look-ahead mask and pass it to MultiHeadAttention.
input_layer = Input(shape=(2, 10))
look_ahead_mask = np.tril(np.ones((2, 2)), 0)
look_ahead_mask = tf.constant(look_ahead_mask)
x = MultiHeadAttention(num_heads=2, key_dim=2)(
    input_layer, input_layer, input_layer, look_ahead_mask
)
x = Dense(1)(x)
model = Model(inputs=input_layer, outputs=x, name="model")
model.compile(loss='mse')

# Minimal training data, then save the model to HDF5.
X = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
              [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]).reshape(1, 2, 10)
Y = np.array([.5]).reshape(-1, 1)
model.fit(X, Y)
model.save("model_test.h5")

# Loading the saved model raises the AttributeError.
from keras.models import load_model
model2 = load_model("model_test.h5")
The following exception is raised by load_model:
Source code / logs.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Input In [10], in <module>
----> 1 model2 = load_model("model_test.h5")
File /opt/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
File /opt/conda/lib/python3.9/site-packages/keras/layers/attention/multi_head_attention.py:435, in MultiHeadAttention._masked_softmax(self, attention_scores, attention_mask)
431 if attention_mask is not None:
432 # The expand dim happens starting from the `num_heads` dimension,
433 # (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
434 mask_expansion_axis = -len(self._attention_axes) * 2 - 1
--> 435 for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
436 attention_mask = tf.expand_dims(
437 attention_mask, axis=mask_expansion_axis)
438 return self._softmax(attention_scores, attention_mask)
AttributeError: Exception encountered when calling layer "multi_head_attention" (type MultiHeadAttention).
'list' object has no attribute 'shape'
Call arguments received by layer "multi_head_attention" (type MultiHeadAttention):
• query=tf.Tensor(shape=(None, 2, 10), dtype=float32)
• value=tf.Tensor(shape=(None, 2, 10), dtype=float32)
• key=tf.Tensor(shape=(None, 2, 10), dtype=float32)
• attention_mask=[['1.0', '0.0'], ['1.0', '1.0']]
• return_attention_scores=False
• training=False
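Until a fix lands, one workaround (a sketch, under the assumption that the architecture can be rebuilt in code) is to reconstruct the model exactly as in the MRE and load only the weights; load_weights accepts a full-model .h5 file and reads just the weights, so the list-valued mask in the saved config is never deserialized:

from keras.layers import MultiHeadAttention, Input, Dense
from keras.models import Model
import tensorflow as tf
import numpy as np

# Rebuild the exact architecture from the MRE.
input_layer = Input(shape=(2, 10))
look_ahead_mask = tf.constant(np.tril(np.ones((2, 2))))
x = MultiHeadAttention(num_heads=2, key_dim=2)(
    input_layer, input_layer, input_layer, look_ahead_mask
)
x = Dense(1)(x)
model2 = Model(inputs=input_layer, outputs=x, name="model")

# Weights only; skips the graph deserialization where the mask arrives as a list.
model2.load_weights("model_test.h5")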
@gadagashwini I was able to replicate the issue on colab, please find the gist here. Thank you!
Hi @Twansklaf, it looks like one of your inputs is a list; can you convert it to a numpy array and try? Thank you!
@gadagashwini this is exactly the issue; please read the description and MRE. The list is read from the saved model, but it was a numpy array (and then a tensor) when the model was created. In the code example you can see:
look_ahead_mask = np.ones((2, 2))
look_ahead_mask = np.tril(look_ahead_mask, 0)
look_ahead_mask = tf.constant(look_ahead_mask)
Up. I'd be happy to push a PR and fix some code if someone would point me in the right direction.
I would recommend simply adding a convert_to_tensor call when receiving a non-None mask in MultiHeadAttention.__init__. Thank you!
To fix this issue, I added the following lines at the beginning of the call() method in the MultiHeadAttention class:
if attention_mask is not None:
    attention_mask = tf.cast(tf.convert_to_tensor(attention_mask), tf.bool)
Why is this still open/unsolved? Is it just waiting for a PR?
The issue is occurring again in the latest TensorFlow release (2.16.1). The older version (2.15.0) was working fine.
@adt5-uiuc, I tried to execute the code on the latest Keras 3.0 with the related APIs and observed that it ran without any issue/error. Kindly find the gist of it here. Thank you!
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.