
gpt2 results with past_key_values not the same as when computed from scratch

Open IanMagnusson opened this issue 2 years ago • 7 comments

System Info

  • transformers version: 4.20.1
  • Platform: Linux-5.4.0-89-generic-x86_64-with-glibc2.31
  • Python version: 3.9.12
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.0+cu113 (False)
  • Tensorflow version (GPU?): 2.9.1 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@patil-suraj @patrickvonplaten @LysandreJik

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

Below is a minimal example that reproduces this unexpected behavior, which I encountered while tinkering with past_key_values. Essentially, I cache keys and values from a padded batch and then use past_key_values to run a forward pass on one additional token per example. When I do this, I get slightly different logits than if I compute the full inputs from scratch and look at the last tokens.

It seems that something goes wrong when past_key_values covers padded positions. However, I believe I am using attention_mask correctly: I include the masking strategy that was used for past_key_values, as specified in the docs.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

s = ["a b c", "l m n o"]
inputs1 = tokenizer(s, return_tensors='pt', padding=True)
outputs1 = model(**inputs1)

s = [" d", " p"]
inputs2 = tokenizer(s, return_tensors='pt', padding=True)
attention_mask = torch.cat((inputs1['attention_mask'], inputs2['attention_mask']), dim=1)
outputs2 = model(input_ids=inputs2['input_ids'], attention_mask=attention_mask, past_key_values=outputs1.past_key_values)

s = ["a b c d", "l m n o p"]
inputs_full = tokenizer(s, return_tensors='pt', padding=True)
outputs_full = model(**inputs_full)

assert torch.allclose(outputs2.logits[1,0],outputs_full.logits[1,-1]) # are second example last token logits the same? -> passes
assert torch.allclose(outputs2.logits[0,0], outputs_full.logits[0,-2]) # are first example last token logits the same? -> fails

Expected behavior

The logits for a given token should be the same whether past_key_values is used for the preceding tokens or the full input is computed from scratch.

Thanks so much for all your hard work on this great library!

IanMagnusson, Jul 12 '22

On further inspection, I believe the source of the difference is the position_ids. When the batched and padded past_key_values are used, the default position_ids are computed by this code:

if past_key_values is None:
    past_length = 0
    past_key_values = tuple([None] * len(self.h))
else:
    past_length = past_key_values[0][0].size(-2)
if position_ids is None:
    position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

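Because past_length includes the padded positions of past_key_values, the default position_ids for the new tokens differ from those used when everything is computed from scratch. In my example the cache is 4 tokens long, so both new tokens get position 4, whereas from scratch " d" sits at position 3, right after "a b c".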
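The arithmetic can be replayed without loading the model. The following pure-Python sketch mirrors the `torch.arange(past_length, input_shape[-1] + past_length)` line above for the repro. The helper name `default_position_ids` is hypothetical, and the prefix lengths (3 tokens for "a b c", 4 for "l m n o") are assumed from the GPT-2 tokenization discussed in this thread.

```python
def default_position_ids(past_length, num_new_tokens, batch_size):
    # Mirrors torch.arange(past_length, num_new_tokens + past_length) broadcast
    # over the batch: every row gets the same positions, padded or not.
    row = list(range(past_length, past_length + num_new_tokens))
    return [list(row) for _ in range(batch_size)]


# Repro from the original post: prefixes "a b c" (3 tokens + 1 pad) and
# "l m n o" (4 tokens), then one new token per example (" d" and " p").
prefix_lengths = [3, 4]
past_length = max(prefix_lengths)  # the cache length includes the pad slot
cached = default_position_ids(past_length, 1, batch_size=2)

# From scratch with right padding, each new token sits right after its own
# prefix, i.e. at position len(prefix).
from_scratch = [[n] for n in prefix_lengths]

print(cached)        # [[4], [4]]
print(from_scratch)  # [[3], [4]]
```

Only row 0 (the padded "a b c" example) disagrees, which is consistent with the assert on example 0 failing while the one on example 1 passes.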

I tested this: if you modify my minimal example in the original post by passing position_ids = torch.tensor([[3],[4]], dtype=torch.int64) to the model's forward pass, both asserts pass. So manually specifying the position_ids solves this problem.
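Rather than hard-coding [[3],[4]], the same values can be read off the prefix attention mask: with right padding, each new token's position equals the number of real prefix tokens. A minimal pure-Python sketch, where the helper `next_position_ids` is hypothetical and assumes exactly one new token per example:

```python
def next_position_ids(prefix_attention_mask):
    # One new token per row: its position is the count of real (non-pad)
    # prefix tokens, which is correct when the prefix was right-padded.
    return [[sum(row)] for row in prefix_attention_mask]


# Attention mask of the right-padded prefix batch ["a b c", "l m n o"]
prefix_mask = [[1, 1, 1, 0],
               [1, 1, 1, 1]]

print(next_position_ids(prefix_mask))  # [[3], [4]]
```

With torch tensors, the equivalent should be `inputs1['attention_mask'].sum(dim=1, keepdim=True)`, which has shape (batch_size, 1) like the hard-coded tensor.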

IanMagnusson, Jul 12 '22

I won't have time to look into this, I'm afraid. @ArthurZucker could you give it a try?

patrickvonplaten, Sep 27 '22

Yep, I will have a look asap

ArthurZucker, Sep 27 '22

So! Sorry for the late reply. My first answer would be that the attention_mask and the inputs differ between the two cases:

  • In the first case, you are feeding [64, 275, 269, 50256] and then [288], with the combined attention mask [1, 1, 1, 0, 1].
  • In the second case, you are feeding [64, 275, 269, 288, 50256] with the attention mask [1, 1, 1, 1, 0].

I thought that using padding_side='left' would fix it; let me investigate!

ArthurZucker, Dec 13 '22

Okay, this fixes it for me:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

s = ["a b c", "l m n o p q r s t u v w x y"]
inputs1 = tokenizer(s, return_tensors='pt', padding=True)
# With left padding, the first sequence is padded at the start: [50256] * 11 + [64, 275, 269]
outputs1 = model(**inputs1)

s = [" d", " z"]
inputs2 = tokenizer(s, return_tensors='pt', padding=True)
attention_mask = torch.cat((inputs1['attention_mask'], inputs2['attention_mask']), dim=-1)
# outputs1.past_key_values[0][0].shape
# torch.Size([2, 12, 14, 64])
outputs2 = model(input_ids=inputs2['input_ids'], attention_mask=attention_mask, past_key_values=outputs1.past_key_values)

s = ["a b c d", "l m n o p q r s t u v w x y z"]
inputs_full = tokenizer(s, return_tensors='pt', padding=True)
outputs_full = model(**inputs_full)


assert torch.allclose(outputs2.logits[1,0],outputs_full.logits[1,-1]) # are second example last token logits the same? -> passes
assert torch.allclose(outputs2.logits[0,0], outputs_full.logits[0,-1]) # are first example last token logits the same? -> passes
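Why left padding helps here: GPT-2's default position ids are the padded column indices. When both suffixes have the same length, the left padding in front of each prefix is identical in the cached and the from-scratch batches, so every token lands on the same column. A pure-Python sketch, assuming "a b c" is 3 tokens and "l m n o p q r s t u v w x y" is 14 (one GPT-2 token per letter); `left_padded_positions` is a hypothetical helper:

```python
def left_padded_positions(lengths):
    # Left-pad every row to the batch width. GPT-2's default position ids are
    # arange(width), so a token's position is its padded column index.
    width = max(lengths)
    return [list(range(width - n, width)) for n in lengths]


prefix_lengths = [3, 14]  # assumed: "a b c", "l m n o p q r s t u v w x y"
suffix_lengths = [1, 1]   # " d", " z": equal, so the suffix batch has no padding

# Cached path: prefix columns, then the suffix continues from past_length.
past_length = max(prefix_lengths)
cached = [row + list(range(past_length, past_length + s))
          for row, s in zip(left_padded_positions(prefix_lengths), suffix_lengths)]

# From-scratch path on the concatenated texts.
full = left_padded_positions([p + s for p, s in zip(prefix_lengths, suffix_lengths)])

print(cached[0])  # [11, 12, 13, 14]
print(full[0])    # [11, 12, 13, 14]
```

The positions are shifted by the amount of left padding (row 0 starts at 11, not 0); they agree between the two paths only because that shift is the same in both.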

ArthurZucker, Dec 13 '22

@ArthurZucker thanks for looking into this! Yes, using padding_side="left" seems like a great solution to this issue!

I'm curious: what is the intended path for users to discover this usage? I can see that the most common use case for past_key_values is sequential decoding, in which case batched generation already mandates left padding. However, some users, like myself, use past_key_values to compute the likelihoods of a set of reference texts that share a prefix, which can be cached with past_key_values. In that case, the need for left padding won't become apparent until one considers what happens to the position_ids, as we have here.

I wonder if the documentation for the past_key_values and attention_mask parameters of forward could mention that left padding will preserve the position_ids. Below is a possibility with changes in bold. It's just a thought, in case it might be helpful. Thank you for your consideration!

past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values output below). Can be used to speed up sequential decoding. The input_ids which have their past given to this model should not be passed as input_ids as they have already been computed. However, the attention_mask of given past input_ids does need to be provided (see attention_mask).

attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked. If past_key_values is used, attention_mask needs to contain the masking strategy that was used for past_key_values. In other words, the attention_mask always has to have the length: len(past_key_values) + len(input_ids). For batching with past_key_values, left padding is required to make uninterrupted attention_masks that preserve position_ids.
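What "uninterrupted" means can be checked mechanically: a mask is interrupted when a padded slot sits between two real tokens. Below is a small pure-Python sketch. The helper `has_hole` is hypothetical, not a transformers function, and it is applied to the masks from @ArthurZucker's comment above and to the left-padded variant.

```python
def has_hole(mask):
    # A "hole" is a masked (0) slot between two real (1) tokens.
    # Leading or trailing padding alone is not a hole.
    ones = [i for i, m in enumerate(mask) if m == 1]
    if not ones:
        return False
    return 0 in mask[ones[0]:ones[-1] + 1]


print(has_hole([1, 1, 1, 0, 1]))  # True:  right-padded cache + new token
print(has_hole([1, 1, 1, 1, 0]))  # False: from scratch, trailing pad only
print(has_hole([0, 1, 1, 1, 1]))  # False: left-padded cache + new token
```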

IanMagnusson, Jan 02 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Jan 26 '23

@ArthurZucker what if the suffixes ([" d", " z"] in the example) have different numbers of tokens? I changed the suffixes to [" d e", " z"] and no longer get the expected result:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

s = ["a b c", "l m n o p q r s t u v w x y"]
inputs1 = tokenizer(s, return_tensors='pt', padding=True)
outputs1 = model(**inputs1)

s = [" d e", # add e
     " z"]
inputs2 = tokenizer(s, return_tensors='pt', padding=True)
attention_mask = torch.cat((inputs1['attention_mask'],
                            inputs2['attention_mask']),
                           dim=-1)
outputs2 = model(input_ids=inputs2['input_ids'],
                 attention_mask=attention_mask,
                 past_key_values=outputs1.past_key_values)

s = ["a b c d e", # add e
     "l m n o p q r s t u v w x y z"]
inputs_full = tokenizer(s, return_tensors='pt', padding=True)
outputs_full = model(**inputs_full)

assert torch.allclose(outputs2.logits[0,-1], outputs_full.logits[0,-1])
# are first example last token logits the same? -> fails

assert torch.allclose(outputs2.logits[1,-1], outputs_full.logits[1,-1])
# are second example last token logits the same? -> fails
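The likely reason this fails: the suffixes no longer add the same number of tokens to each row, so the amount of left padding differs between the cached and the from-scratch batches, and " z" is additionally left-padded inside the suffix batch. A pure-Python sketch of GPT-2's default position ids under these assumptions (token counts 3 and 14 for the prefixes, 2 and 1 for the suffixes; `left_pad_positions` is a hypothetical helper):

```python
def left_pad_positions(lengths, start=0):
    # Left-pad each row to the batch width; default position ids are
    # start + column index, so a row's real tokens take the last n columns.
    width = max(lengths)
    return [list(range(start + width - n, start + width)) for n in lengths]


prefix_lengths = [3, 14]  # assumed: "a b c", "l m n o p q r s t u v w x y"
suffix_lengths = [2, 1]   # assumed: " d e", " z" (no longer equal)

# Cached path: the left-padded suffix batch continues from past_length.
past_length = max(prefix_lengths)
cached_suffix = left_pad_positions(suffix_lengths, start=past_length)

# From-scratch path: left-pad the concatenated texts, keep the suffix tokens.
full_lengths = [p + s for p, s in zip(prefix_lengths, suffix_lengths)]
full_suffix = [row[-s:] for row, s in zip(left_pad_positions(full_lengths), suffix_lengths)]

print(cached_suffix)  # [[14, 15], [15]]
print(full_suffix)    # [[13, 14], [14]]
```

Both rows end up one position off, which matches both asserts failing.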

Edit: I think I have a general solution. Will add another comment

kddubey, Mar 12 '23

A general solution (general meaning that both the prefixes and the suffixes can have different numbers of tokens) is to create and supply position_ids, as @IanMagnusson found above. I also think right padding is the more correct choice, because the prefix position ids are then the same as they would be without padding.
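The point about right padding can be seen by listing where the real prefix tokens land. Since GPT-2's default position ids are the padded column indices, right padding leaves each prefix at 0..n-1 exactly as if unpadded, while left padding shifts the shorter prefix. The pure-Python sketch below (hypothetical helper `real_token_positions`; assumed token counts of 3 and 14 for the prefixes and 2 and 1 for the suffixes in the demo that follows) also reproduces the demo's `offsets[:, None] + arange(num_completion_tokens)` position ids.

```python
def real_token_positions(mask):
    # GPT-2's default position ids are arange(width), i.e. the column index,
    # so each real (mask == 1) token's position is simply its column.
    return [i for i, m in enumerate(mask) if m == 1]


prefix_lengths = [3, 14]  # assumed token counts of the two prefixes
suffix_lengths = [2, 1]   # assumed token counts of ' d e' and ' z'
width = max(prefix_lengths)

right = [real_token_positions([1] * n + [0] * (width - n)) for n in prefix_lengths]
left = [real_token_positions([0] * (width - n) + [1] * n) for n in prefix_lengths]
unpadded = [list(range(n)) for n in prefix_lengths]

print(right[0])  # [0, 1, 2]     same as unpadded
print(left[0])   # [11, 12, 13]  shifted by 11 pad slots

# Suffix ids: offsets[:, None] + arange(num_completion_tokens), where
# offsets (the attention-mask sum) is the number of real prefix tokens.
offsets = prefix_lengths
num_completion_tokens = max(suffix_lengths)
suffix_position_ids = [[o + j for j in range(num_completion_tokens)] for o in offsets]
print(suffix_position_ids)  # [[3, 4], [14, 15]]
```

Row 1's second id (15) belongs to the right-pad slot after " z", so it does not affect any real token; the ids 3, 4 and 14 match the logit indices checked in the tests.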

Demo

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token # allow batching
if tokenizer.padding_side != 'right':
    raise ValueError('Gotta use right padding to ensure position IDs are '
                     'correct.')


prefixes = ['a b c',
            'l m n o p q r s t u v w x y']
# Make sure to start each suffix w/ a whitespace
suffixes = [' d e',
            ' z']


# Batch inference prefixes
prefixes_encoding = tokenizer(prefixes, return_tensors='pt', padding=True)
with torch.no_grad():
    prefixes_out = model(**prefixes_encoding)
# Need offsets so that position_ids for future tokens are set correctly
offsets = prefixes_encoding.attention_mask.sum(dim=1)


# Batch inference suffixes
suffixes_encoding = tokenizer(suffixes, return_tensors='pt',
                              padding=True)
num_completion_tokens = suffixes_encoding.input_ids.shape[1]

# Set position_ids to what they were had we fed each prefix + suffix
# together w/ right-padding (right-padding b/c GPT-2 uses absolute position ids)
suffixes_position_ids = (torch.arange(0, num_completion_tokens) +
                         offsets[:, None]) # broadcast

# Need attention_mask to include the prefixes since it could have padding
attention_mask = torch.cat((prefixes_encoding.attention_mask,
                            suffixes_encoding.attention_mask),
                            dim=1)


# Everything should now be aligned 🤞 🙏
with torch.no_grad():
    suffixes_out = model(input_ids=suffixes_encoding.input_ids,
                         attention_mask=attention_mask,
                         past_key_values=prefixes_out.past_key_values,
                         position_ids=suffixes_position_ids)

Tests


# Expected output
full = [prefix + suffix for prefix, suffix in zip(prefixes, suffixes)]
full_encoding = tokenizer(full, return_tensors='pt', padding=True)
with torch.no_grad():
    full_out = model(**full_encoding)


# Test shape
assert suffixes_out.logits.shape[0]  == full_out.logits.shape[0]
assert suffixes_out.logits.shape[-1] == full_out.logits.shape[-1]


# Test that every non-pad token's logits are close.
# (in the comments, the token in parentheses is the one whose logits we're
#  accessing)
assert torch.allclose(suffixes_out.logits[0, 0], # (d), e
                          full_out.logits[0, 3]) # a, b, c, (d), e, rest are <PAD>

assert torch.allclose(suffixes_out.logits[0, 1], # d, (e)
                          full_out.logits[0, 4]) # a, b, c, d, (e), rest are <PAD>

assert torch.allclose(suffixes_out.logits[1,  0], # (z), <PAD>
                          full_out.logits[1, -1]) # l m n o p q r s t u v w x y (z)

kddubey, Mar 12 '23

Hey! Yes, as mentioned before, unlike some of our other models, GPT-2 does not create the position IDs on the fly. A fix is in the making (see #21853) which should remove the need to pass the position ids yourself.
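One common way to create position ids "on the fly" is to derive them from the attention mask, roughly `attention_mask.long().cumsum(-1) - 1` with padded slots filled by a dummy value (GPT-2's own `prepare_inputs_for_generation` does something along these lines). Whether #21853 does exactly this is not stated in the thread. A pure-Python sketch of the idea, with the hypothetical helper `positions_from_mask` and the masks from the examples above (assumed token counts 3 and 14 for the prefixes):

```python
def positions_from_mask(mask):
    # Running count of real tokens minus one; padded slots get a dummy 1
    # (they are masked out of attention anyway).
    out, count = [], 0
    for m in mask:
        count += m
        out.append(count - 1 if m == 1 else 1)
    return out


row0_right = [1] * 3 + [0] * 11 + [1, 1]  # "a b c" right-padded, then " d e"
row1_left_suffix = [1] * 14 + [0, 1]      # 14-token prefix, then left-padded " z"
row0_left = [0] * 11 + [1] * 3 + [1, 1]   # "a b c" left-padded, then " d e"

print(positions_from_mask(row0_right)[-2:])        # [3, 4]
print(positions_from_mask(row1_left_suffix)[-1:])  # [14]
print(positions_from_mask(row0_left)[-2:])         # [3, 4]
```

In all three cases the suffix positions match those of the unpadded text (" d e" at 3 and 4 after "a b c"; " z" at 14 after the 14-token prefix), regardless of padding side or holes in the mask.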

ArthurZucker, Mar 13 '23