
BertLMHeadModel (w/ relative position embedding) does not work correctly when use_cache = True

Open jsh710101 opened this issue 2 years ago • 3 comments

System Info

  • transformers version: 4.20.1
  • Platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.17
  • Python version: 3.8.13
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@LysandreJik

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

I found that BertLMHeadModel (w/ relative position embedding) sometimes generates unexpected sequences when use_cache = True.

Here is a minimal code sample that indirectly demonstrates this problem:

import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    is_decoder=True, vocab_size=10, hidden_size=64, num_hidden_layers=1, num_attention_heads=4,
    intermediate_size=64, position_embedding_type='relative_key')
model = BertLMHeadModel(config).eval()

with torch.no_grad():
    model.config.use_cache = False
    generation = model.generate(bos_token_id=1, max_length=5, output_attentions=True, return_dict_in_generate=True)
    print(generation.attentions[-1][0][:, :, -1:, :])

    prediction = model(input_ids=generation.sequences[:, :-1], output_attentions=True)
    print(prediction.attentions[0][:, :, -1:, :])

    model.config.use_cache = True
    generation = model.generate(bos_token_id=1, max_length=5, output_attentions=True, return_dict_in_generate=True)
    print(generation.attentions[-1][0])

Outputs:

tensor([[[[0.2455, 0.2530, 0.2558, 0.2457]],

         [[0.2495, 0.2492, 0.2497, 0.2516]],

         [[0.2481, 0.2516, 0.2514, 0.2489]],

         [[0.2496, 0.2538, 0.2533, 0.2433]]]])
tensor([[[[0.2455, 0.2530, 0.2558, 0.2457]],

         [[0.2495, 0.2492, 0.2497, 0.2516]],

         [[0.2481, 0.2516, 0.2514, 0.2489]],

         [[0.2496, 0.2538, 0.2533, 0.2433]]]])
tensor([[[[0.2452, 0.2532, 0.2548, 0.2468]],

         [[0.2498, 0.2492, 0.2494, 0.2516]],

         [[0.2485, 0.2516, 0.2516, 0.2483]],

         [[0.2492, 0.2538, 0.2528, 0.2442]]]])

Expected behavior

The three printed attention tensors should have the same values, but they are different. (The generated sequences happen to be the same in this case, but once the model is trained, different sequences are generated depending on use_cache.)

The cause of this problem is that BertSelfAttention's relative position embedding does not handle the use_cache = True case properly. It seems that this problem can be fixed by modifying BertSelfAttention's forward function as follows:

# ...

use_cache = past_key_value is not None
if self.is_decoder:
    # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
    # Further calls to cross_attention layer can then reuse all cross-attention
    # key/value_states (first "if" case)
    # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
    # all previous decoder key/value_states. Further calls to uni-directional self-attention
    # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
    # if encoder bi-directional self-attention `past_key_value` is always `None`
    past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    query_length, key_length = query_layer.shape[2], key_layer.shape[2]
    if use_cache:
        position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    else:
        position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r

    # ...

(The current code always makes the distance variable become tensor([[0]]) when use_cache = True.)
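For example, when generating the 3rd token with a cache (query_length = 1, key_length = 3), the modified computation above yields the full distance row. A minimal sanity check of just that part, assuming those lengths:

import torch

# With the proposed fix, the single cached-step query sits at position key_length - 1 = 2.
key_length = 3
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long).view(1, -1)
print(position_ids_l - position_ids_r)  # tensor([[2, 1, 0]]) instead of tensor([[0]])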

Other models using the same code also need modifications...

Also, BertLMHeadModel's generate function does not overwrite the use_cache option. It seems that BertLMHeadModel's prepare_inputs_for_generation function should add a use_cache item to the output dictionary, similar to this.
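A rough sketch of what that could look like, based on the prepare_inputs_for_generation method as of transformers 4.20 (the exact upstream code may differ; the use_cache forwarding is the proposed addition):

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **model_kwargs):
    input_shape = input_ids.shape
    # create a default attention mask on the fly if none is provided
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_shape)
    # with a cache, only the last token needs to be fed to the model
    if past is not None:
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past,
        # proposed addition: forward the flag so generate() actually controls caching
        "use_cache": use_cache,
    }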

jsh710101 avatar Sep 15 '22 05:09 jsh710101

Thanks for opening an issue! @ArthurZucker or @ydshieh, could you take a look at what might be going on here?

LysandreJik avatar Sep 15 '22 13:09 LysandreJik

@jsh710101 Before going deeper, could you also try without relative position embedding, and post the results here?

ydshieh avatar Sep 16 '22 12:09 ydshieh

Thank you for your quick reply, of course!

With absolute position embedding (position_embedding_type = 'absolute'), the minimal code sample outputs:

tensor([[[[0.2540, 0.2539, 0.2472, 0.2449]],

         [[0.2519, 0.2480, 0.2483, 0.2518]],

         [[0.2473, 0.2475, 0.2517, 0.2535]],

         [[0.2496, 0.2523, 0.2491, 0.2491]]]])
tensor([[[[0.2540, 0.2539, 0.2472, 0.2449]],

         [[0.2519, 0.2480, 0.2483, 0.2518]],

         [[0.2473, 0.2475, 0.2517, 0.2535]],

         [[0.2496, 0.2523, 0.2491, 0.2491]]]])
tensor([[[[0.2540, 0.2539, 0.2472, 0.2449]],

         [[0.2519, 0.2480, 0.2483, 0.2518]],

         [[0.2473, 0.2475, 0.2517, 0.2535]],

         [[0.2496, 0.2523, 0.2491, 0.2491]]]])

The three attention tensors are the same as expected. (I checked the code and it seems to be implemented correctly.)

To be specific, if we are generating the 3rd token (w/ relative position embedding & use_cache = True), the distance tensor should be tensor([[2, 1, 0]]), but the current implementation (code below) always makes it tensor([[0]]) because seq_length is always assigned 1.

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    seq_length = hidden_states.size()[1]
    position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r
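For comparison, during a cached decoding step hidden_states only contains the newest token, so the snippet above collapses to seq_length = 1:

import torch

seq_length = 1  # hidden_states.size(1) during cached decoding
position_ids_l = torch.arange(seq_length, dtype=torch.long).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long).view(1, -1)
print(position_ids_l - position_ids_r)  # tensor([[0]]), not the expected tensor([[2, 1, 0]])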

jsh710101 avatar Sep 16 '22 13:09 jsh710101

@ydshieh Ah... I forgot to tag you.

I think I can fix this problem. If you agree that the above should be handled and you're not working on it, I'll try to fix the code and open a pull request.

jsh710101 avatar Oct 16 '22 14:10 jsh710101

Hey, I am currently investigating whether we should indeed change the attention or not. As a lot of models depend on it, I wanna make sure this would be backward compatible! But if you want, feel free to open a PR. 😄

ArthurZucker avatar Oct 16 '22 16:10 ArthurZucker

Hey! So after investigating in detail, it seems that we indeed have a problem, but the good news is that it is not a major issue.

First, we have to use a model that was trained with relative_key, so I used "zhiheng-huang/bert-base-uncased-embedding-relative-key".

  • The attention scores are indeed different, and so are the last logits, but the token predicted after the softmax ends up the same. This seems to come from the learned embedding, which doesn't have a huge impact once the model has already been trained, but could impact training.

Minimal reproducing script:

import torch
from transformers import BertTokenizer, BertLMHeadModel, set_seed

tokenizer = BertTokenizer.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
model = BertLMHeadModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key", is_decoder=True)
inputs = tokenizer("No I'm not missing the ", return_tensors="pt")
input_ids = inputs.input_ids[:, :-1]
attention_mask = inputs.attention_mask[:, :-1]

with torch.no_grad():
    model.config.use_cache = False
    set_seed(0)
    output = model(input_ids, attention_mask=attention_mask, use_cache=False)
    print(output.logits[:, -1, :])

    model.config.use_cache = True
    output_1 = model(input_ids[:, :-1], use_cache=True, attention_mask=attention_mask[:, :-1])
    pkv = output_1.past_key_values
    output_2 = model(input_ids[:, -1:], past_key_values=pkv, use_cache=True)
    print(output_2.logits[:, -1, :])

tensor([[-5.4971, -6.4888, -8.3359,  ..., -7.3612, -5.5480, -0.9784]])
tensor([[ -7.2693,  -7.7799, -10.0905,  ...,  -7.5183,  -7.4255,  -4.6804]])

With your fix we indeed get:

tensor([[-5.4971, -6.4888, -8.3359,  ..., -7.3612, -5.5480, -0.9784]])
tensor([[-5.4971, -6.4888, -8.3359,  ..., -7.3612, -5.5480, -0.9784]])

This should have been tested when merging the model, but it seems like it was not. I will open a PR to address this.
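For reference, a hedged sketch of the kind of cached-vs-uncached equivalence check that could have caught this (not the actual test in the repository, just an illustration with an untrained toy config):

import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    is_decoder=True, vocab_size=10, hidden_size=64, num_hidden_layers=1,
    num_attention_heads=4, intermediate_size=64, position_embedding_type="relative_key")
model = BertLMHeadModel(config).eval()
input_ids = torch.tensor([[1, 2, 3, 4]])

with torch.no_grad():
    # full forward pass without caching
    full = model(input_ids, use_cache=False).logits[:, -1, :]
    # same prediction computed incrementally through the cache
    cached = model(input_ids[:, :-1], use_cache=True)
    incremental = model(input_ids[:, -1:], past_key_values=cached.past_key_values, use_cache=True).logits[:, -1, :]

# expected to fail with relative_key embeddings before the fix, and pass after it
torch.testing.assert_close(full, incremental)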

ArthurZucker avatar Nov 14 '22 10:11 ArthurZucker