keras-nlp
The attention scores are always `None` in `CachedMultiHeadAttention`
Describe the bug
The variable `attention_scores` introduced at line 111 is always `None`.
To Reproduce
Since it is an internal variable, I copied the `CachedMultiHeadAttention` (CMHA) subclass into this script: https://colab.research.google.com/drive/1ZUS4mjDQktovKiJ8TQ7zYtm4PGjesXvG?usp=sharing
Expected behavior
The variable `attention_scores` should contain the attention scores between query and key (the softmax-normalized query-key dot products), which IMHO is useful for debugging a model.
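For reference, here is a minimal single-head sketch of what these scores are in standard scaled dot-product attention: a softmax over the key axis of the query-key dot products divided by the square root of the key dimension. The function name `attention_scores` is hypothetical. Masking, dropout, multiple heads and batching, which the real layer handles, are deliberately omitted.

```python
import math


def attention_scores(query, key):
    """Scaled dot-product attention scores for a single head (illustrative only).

    query: list of T_q query vectors; key: list of T_k key vectors.
    Returns a T_q x T_k matrix whose rows each sum to 1.
    """
    scale = 1.0 / math.sqrt(len(key[0]))
    scores = []
    for q in query:
        # Scaled dot product of this query with every key.
        logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in key]
        # Numerically stable softmax over the key axis.
        peak = max(logits)
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        scores.append([e / total for e in exps])
    return scores


# Two identical keys receive equal attention.
print(attention_scores([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]]))  # [[0.5, 0.5]]
```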
Additional context
In recent Keras versions, the parent class `MultiHeadAttention` stores the `return_attention_scores` argument in `self._return_attention_scores`.
Then, the method `_compute_attention` checks this private attribute to decide whether or not to return the scores.
Since `CachedMultiHeadAttention.call` never updates this attribute, the attention scores are never returned.
I'll also submit an issue to Keras to turn the attribute `_return_attention_scores` into an argument.
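The interaction can be illustrated with a simplified mock. `ToyMultiHeadAttention` and `ToyCachedAttention` are hypothetical stand-ins, not the real Keras or KerasNLP classes: they only mimic the flag handling described above, and their "scores" are placeholder products rather than real attention scores.

```python
class ToyMultiHeadAttention:
    """Hypothetical stand-in for keras.layers.MultiHeadAttention (flag handling only)."""

    def __init__(self):
        # The flag is only updated when the parent's own call() runs.
        self._return_attention_scores = False

    def _compute_attention(self, query, key):
        # Placeholder "scores": pairwise products of query and key entries.
        scores = [[q * k for k in key] for q in query]
        output = [sum(row) for row in scores]
        # Scores are only returned when the private flag is set.
        if self._return_attention_scores:
            return output, scores
        return output, None

    def call(self, query, key, return_attention_scores=False):
        self._return_attention_scores = return_attention_scores
        output, scores = self._compute_attention(query, key)
        return (output, scores) if return_attention_scores else output


class ToyCachedAttention(ToyMultiHeadAttention):
    """Hypothetical stand-in for CMHA: overrides call() without
    touching self._return_attention_scores."""

    def call(self, query, key):
        attention_output, attention_scores = self._compute_attention(query, key)
        return attention_output, attention_scores


output, scores = ToyCachedAttention().call([1.0, 2.0], [3.0, 4.0])
print(scores)  # None: the subclass never sets the flag
```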
Would you like to help us fix it?
Yes, I have two potential fixes:
- ignore the attention scores entirely, which would be consistent since the corresponding argument has been removed from CMHA
- add the relevant argument and set the attribute `self._return_attention_scores` accordingly
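The second option could look roughly like this, again on a hypothetical mock rather than the real classes. `ToyMultiHeadAttention` and `ToyCachedAttentionFixed` are illustrative names, the placeholder "scores" are simple products, and the real CMHA also handles a cache, which is omitted here.

```python
class ToyMultiHeadAttention:
    """Hypothetical stand-in for keras.layers.MultiHeadAttention (flag handling only)."""

    def __init__(self):
        self._return_attention_scores = False

    def _compute_attention(self, query, key):
        # Placeholder "scores": pairwise products of query and key entries.
        scores = [[q * k for k in key] for q in query]
        output = [sum(row) for row in scores]
        if self._return_attention_scores:
            return output, scores
        return output, None


class ToyCachedAttentionFixed(ToyMultiHeadAttention):
    """Option 2 on the mock: accept the argument and sync the private flag."""

    def call(self, query, key, return_attention_scores=False):
        # Set the flag that the parent's _compute_attention reads.
        self._return_attention_scores = return_attention_scores
        attention_output, attention_scores = self._compute_attention(query, key)
        if return_attention_scores:
            return attention_output, attention_scores
        return attention_output


layer = ToyCachedAttentionFixed()
output, scores = layer.call([1.0, 2.0], [3.0, 4.0], return_attention_scores=True)
print(scores)  # [[3.0, 4.0], [6.0, 8.0]]
```

Setting the flag on every call, rather than once in the constructor, keeps it in sync when the same layer is called with different values of the argument.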
WDYT?