BertLMHeadModel (w/ relative position embedding) does not work correctly when use_cache = True
System Info
- `transformers` version: 4.20.1
- Platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.17
- Python version: 3.8.13
- Huggingface_hub version: 0.8.1
- PyTorch version (GPU?): 1.12.0 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help?
@LysandreJik
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
I found that `BertLMHeadModel` (w/ relative position embedding) sometimes generates unexpected sequences when `use_cache = True`.
Here is a minimal code sample that indirectly demonstrates this problem:
```python
import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    is_decoder=True, vocab_size=10, hidden_size=64, num_hidden_layers=1, num_attention_heads=4,
    intermediate_size=64, position_embedding_type='relative_key')
model = BertLMHeadModel(config).eval()

with torch.no_grad():
    model.config.use_cache = False
    generation = model.generate(bos_token_id=1, max_length=5, output_attentions=True, return_dict_in_generate=True)
    print(generation.attentions[-1][0][:, :, -1:, :])
    prediction = model(input_ids=generation.sequences[:, :-1], output_attentions=True)
    print(prediction.attentions[0][:, :, -1:, :])

    model.config.use_cache = True
    generation = model.generate(bos_token_id=1, max_length=5, output_attentions=True, return_dict_in_generate=True)
    print(generation.attentions[-1][0])
```
Outputs:
```
tensor([[[[0.2455, 0.2530, 0.2558, 0.2457]],
         [[0.2495, 0.2492, 0.2497, 0.2516]],
         [[0.2481, 0.2516, 0.2514, 0.2489]],
         [[0.2496, 0.2538, 0.2533, 0.2433]]]])
tensor([[[[0.2455, 0.2530, 0.2558, 0.2457]],
         [[0.2495, 0.2492, 0.2497, 0.2516]],
         [[0.2481, 0.2516, 0.2514, 0.2489]],
         [[0.2496, 0.2538, 0.2533, 0.2433]]]])
tensor([[[[0.2452, 0.2532, 0.2548, 0.2468]],
         [[0.2498, 0.2492, 0.2494, 0.2516]],
         [[0.2485, 0.2516, 0.2516, 0.2483]],
         [[0.2492, 0.2538, 0.2528, 0.2442]]]])
```
Expected behavior
The three printed attention tensors should have the same values, but they have different values. (The generated sequences happen to be the same in this case, but once the model is trained, different sequences are generated depending on `use_cache`.)
The cause of this problem is that `BertSelfAttention`'s relative position embedding does not handle the `use_cache = True` case properly. It seems that this problem can be fixed by modifying `BertSelfAttention`'s `forward` function as follows:
```python
# ...
use_cache = past_key_value is not None
if self.is_decoder:
    # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
    # Further calls to cross_attention layer can then reuse all cross-attention
    # key/value_states (first "if" case)
    # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
    # all previous decoder key/value_states. Further calls to uni-directional self-attention
    # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
    # if encoder bi-directional self-attention `past_key_value` is always `None`
    past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    query_length, key_length = query_layer.shape[2], key_layer.shape[2]
    if use_cache:
        position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    else:
        position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r
# ...
```
(The current code always makes the `distance` variable become `tensor([[0]])` when `use_cache = True`.)
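As a quick sanity check of the indexing above (plain tensor arithmetic only, detached from the actual model code), the cached single-row computation reproduces the last row of the full distance matrix:

```python
import torch

# Illustration with a 4-token sequence; key_length/query_length are stand-ins
# for the shapes used in the modified forward above.
key_length, query_length = 4, 4

# No cache: full query_length x key_length distance matrix.
full_l = torch.arange(query_length, dtype=torch.long).view(-1, 1)
full_r = torch.arange(key_length, dtype=torch.long).view(1, -1)
full_distance = full_l - full_r

# With cache: only the newest query token, sitting at position key_length - 1.
cached_l = torch.tensor(key_length - 1, dtype=torch.long).view(-1, 1)
cached_distance = cached_l - full_r

print(full_distance[-1:])  # tensor([[3, 2, 1, 0]])
print(cached_distance)     # tensor([[3, 2, 1, 0]])
```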
Other models using the same code also need modifications...
Also, `BertLMHeadModel`'s `generate` function does not overwrite the `use_cache` option. It seems that `BertLMHeadModel`'s `prepare_inputs_for_generation` function should add a `use_cache` item to the output dictionary, similar to this.
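For illustration, something along these lines could work (a rough sketch, not necessarily the exact upstream signature; the `use_cache` entry is the part being suggested):

```python
# Rough sketch: forward the generation-time `use_cache` flag through to the
# model's forward call instead of silently dropping it.
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **model_kwargs):
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    # Once a cache exists, only the newest token needs to be fed to the model.
    if past is not None:
        input_ids = input_ids[:, -1:]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past,
        "use_cache": use_cache,  # suggested addition
    }
```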
Thanks for opening an issue! @ArthurZucker or @ydshieh, could you take a look at what might be going on here?
@jsh710101 Before going deeper, could you also try without relative position embedding, and post the results here?
Thank you for your quick reply, of course!
With absolute position embedding (`position_embedding_type = 'absolute'`), the minimal code sample outputs:
```
tensor([[[[0.2540, 0.2539, 0.2472, 0.2449]],
         [[0.2519, 0.2480, 0.2483, 0.2518]],
         [[0.2473, 0.2475, 0.2517, 0.2535]],
         [[0.2496, 0.2523, 0.2491, 0.2491]]]])
tensor([[[[0.2540, 0.2539, 0.2472, 0.2449]],
         [[0.2519, 0.2480, 0.2483, 0.2518]],
         [[0.2473, 0.2475, 0.2517, 0.2535]],
         [[0.2496, 0.2523, 0.2491, 0.2491]]]])
tensor([[[[0.2540, 0.2539, 0.2472, 0.2449]],
         [[0.2519, 0.2480, 0.2483, 0.2518]],
         [[0.2473, 0.2475, 0.2517, 0.2535]],
         [[0.2496, 0.2523, 0.2491, 0.2491]]]])
```
The three attention tensors are the same as expected. (I checked the code and it seems to be implemented correctly.)
To be specific, if we are generating the 3rd token (w/ relative position embedding & `use_cache = True`), the `distance` tensor should be `tensor([[2, 1, 0]])`, but the current implementation (code below) always makes it `tensor([[0]])` because `seq_length` is always assigned 1.
```python
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    seq_length = hidden_states.size()[1]
    position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r
```
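In other words, with the cache only the newest token is fed to the layer, so `hidden_states` has sequence length 1 and the arithmetic above collapses to a single zero distance. A plain-tensor illustration:

```python
import torch

# With use_cache=True, hidden_states contains only the newest token, so seq_length == 1.
seq_length = 1
position_ids_l = torch.arange(seq_length, dtype=torch.long).view(-1, 1)  # tensor([[0]])
position_ids_r = torch.arange(seq_length, dtype=torch.long).view(1, -1)  # tensor([[0]])
print(position_ids_l - position_ids_r)  # tensor([[0]]) instead of the expected tensor([[2, 1, 0]])
```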
@ydshieh Ah... I forgot to tag you.
I think I can fix this problem. If you agree that the above should be handled and you're not working on it, I'll try to fix the code and open a pull request.
Hey, I am currently investigating whether we should indeed change the attention or not. As a lot of models depend on it, I wanna make sure this would be backward compatible! But if you want, feel free to open a PR. 😄
Hey! So after investigating in detail, it seems that we indeed have a problem, but the good news is that it is not a major issue.
First, we have to use a model that was trained with `relative_key`, so I used `"zhiheng-huang/bert-base-uncased-embedding-relative-key"`.
- The attention scores are indeed different (and so are the last logits), but the result after the softmax is always the same. This seems to come from the learned embedding, which doesn't have a huge impact once the model has already learned, but could impact training.
Minimal reproducing script:

```python
import torch
from transformers import BertTokenizer, BertLMHeadModel, set_seed

tokenizer = BertTokenizer.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
model = BertLMHeadModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key", is_decoder=True)

inputs = tokenizer("No I'm not missing the ", return_tensors="pt")
input_ids = inputs.input_ids[:, :-1]
attention_mask = inputs.attention_mask[:, :-1]

with torch.no_grad():
    model.config.use_cache = False
    set_seed(0)
    output = model(input_ids, attention_mask=attention_mask, use_cache=False)
    print(output.logits[:, -1, :])

    model.config.use_cache = True
    output_1 = model(input_ids[:, :-1], use_cache=True, attention_mask=attention_mask[:, :-1])
    pkv = output_1.past_key_values
    output_2 = model(input_ids[:, -1:], past_key_values=pkv, use_cache=True)
    print(output_2.logits[:, -1, :])
```
```
tensor([[-5.4971, -6.4888, -8.3359, ..., -7.3612, -5.5480, -0.9784]])
tensor([[-7.2693, -7.7799, -10.0905, ..., -7.5183, -7.4255, -4.6804]])
```
With your fix we indeed have:
```
tensor([[-5.4971, -6.4888, -8.3359, ..., -7.3612, -5.5480, -0.9784]])
tensor([[-5.4971, -6.4888, -8.3359, ..., -7.3612, -5.5480, -0.9784]])
```
This should have been tested when merging the model, but it seems like it was not. I will open a PR to address this.