
GPT NeoX rotary embedding does not work with left padding

OlivierDehaene opened this issue on Mar 14 '23 • 6 comments

System Info

  • transformers version: 4.26.1
  • Platform: Linux-5.4.0-1097-aws-x86_64-with-glibc2.27
  • Python version: 3.10.9
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker, @younesbelkada, @gante

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")

model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto")

f_not_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=False, return_tensors="pt"))
f_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=True, pad_to_multiple_of=256, return_tensors="pt"))

torch.testing.assert_allclose(f_not_padded.logits[:, -1], f_padded.logits[:, -1])

# AssertionError: Tensor-likes are not close!

# Mismatched elements: 6057 / 50288 (12.0%)
# Greatest absolute difference: 0.0003177821636199951 at index (0, 4649) (up to 1e-05 allowed)
# Greatest relative difference: 1.5682868874196898 at index (0, 30410) (up to 0.0001 allowed)

The problem is exacerbated in bfloat16:

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")

model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto", torch_dtype=torch.bfloat16)

f_not_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=False, return_tensors="pt"))
f_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=True, pad_to_multiple_of=256, return_tensors="pt"))

torch.testing.assert_allclose(f_not_padded.logits[:, -1], f_padded.logits[:, -1])

# AssertionError: Tensor-likes are not equal!

# Mismatched elements: 49417 / 50288 (98.3%)
# Greatest absolute difference: 1.154541015625 at index (0, 50271)
# Greatest relative difference: 2058.906976744186 at index (0, 29917)

Expected behavior

Left padding should have no influence on the resulting logits.

While the differences may not look like much, they have a huge impact on generation.
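
For illustration, a minimal sketch of how the discrepancy can surface during greedy generation (same model and prompt as in the reproduction above; the exact point of divergence depends on the checkpoint and hardware):

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto")

prompt = "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"

# Greedy generation without padding.
unpadded = tokenizer([prompt], return_tensors="pt")
out_unpadded = model.generate(**unpadded, max_new_tokens=64, do_sample=False)

# Greedy generation with heavy left padding, as a dynamic batcher would produce.
padded = tokenizer([prompt], padding=True, pad_to_multiple_of=256, return_tensors="pt")
out_padded = model.generate(**padded, max_new_tokens=64, do_sample=False)

# With the bug present, the two continuations can diverge.
print(tokenizer.decode(out_unpadded[0], skip_special_tokens=True))
print(tokenizer.decode(out_padded[0], skip_special_tokens=True))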

OlivierDehaene · Mar 14 '23

Hey, thanks for reporting!

ArthurZucker · Mar 14 '23

It is possible that this is not the root cause, but there is an issue with these lines in GPTNeoXAttention:

offset = 0
if has_layer_past:
    offset = layer_past[0].shape[-2]
    seq_len += offset
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)

offset and seq_len are not computed correctly when you have padding. As a side note, it is impossible to have a single scalar offset, as different sequences in the batch might have different lengths and therefore different offsets with left padding.
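
For illustration, a minimal sketch (not the library's code) of how per-sequence position ids can be derived from the attention mask so that left padding does not shift the rotary positions; it mirrors the cumsum-on-mask pattern that some models in transformers already use in prepare_inputs_for_generation:

import torch

def build_position_ids(attention_mask: torch.Tensor, past_length: int = 0) -> torch.Tensor:
    # attention_mask: [batch, seq_len], with 0 on (left) padding tokens.
    # A cumulative sum over real tokens gives each token its position within its own
    # sequence, so every sequence starts at 0 regardless of how much left padding precedes it.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)  # any valid index works for masked positions
    if past_length > 0:
        # During incremental decoding only the positions of the new tokens are needed.
        position_ids = position_ids[:, past_length:]
    return position_ids

mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence of length 3
                     [1, 1, 1, 1, 1]])  # full-length sequence
print(build_position_ids(mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])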

OlivierDehaene · Mar 15 '23

We use left padding extensively on the serving side, as we have a dynamic batching logic that batches sequences of very different lengths together.

While the pad_to_multiple_of=256 example above seems extreme in isolation, it is completely normal when serving. We sometimes even go higher in chat applications, where one member of the batch has a very long history (> 1000 tokens) while other sequences have only just started (~40 tokens).

We also serve all models in bfloat16 when available, and we almost always use sampling, which amplifies the logits issue even more.
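
As a rough illustration of that batching pattern (the prompt strings below are placeholders, not real traffic):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")

# A long chat history batched together with a request that just started.
long_history = "<|prompter|>" + " ".join(["word"] * 1000) + "<|endoftext|><|assistant|>"
new_request = "<|prompter|>Hello!<|endoftext|><|assistant|>"

batch = tokenizer([long_history, new_request], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)               # both rows are padded to the longest sequence
print(int(batch["attention_mask"][1].sum()))  # only a handful of real tokens; the rest is left padding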

OlivierDehaene · Mar 15 '23

Hey everyone! Yes, that's correct, it is pretty much the same issue as I reported here -- we should be passing position_ids all the way down to the attention layer and computing the sequence length from it.

We have an open PR to fix the same issue for GPT-J (#22069); I'll make sure it is ported to GPT NeoX once it is merged. We are currently ironing out torch.fx issues (adding the correct behavior makes the tensors dynamic, which blocks existing features).

gante · Mar 15 '23

Hi @OlivierDehaene, I'm actually in the middle of porting the fix from #22069 to GPT-Neox too, since I was also interested in that one (in parallel with other things including resolving this torch.fx issue).

Also for reference there's a similar existing issue which went stale: https://github.com/huggingface/transformers/issues/18999

njhill · Mar 15 '23

Hi @njhill! Nice, thanks for working on this! For now I have a fix on my text-generation-inference fork, as we have multiple NeoX models in prod and I need a fix ASAP. It's essentially the same as yours, I think.

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self.cos_cached = None
        self.sin_cached = None

    @staticmethod
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)

    def forward(self, q, k, position_ids, seq_len=None):
        # q, k: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = int(position_ids.max()) + 1
        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
            if seq_len > self.max_seq_len_cached:
                self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._create_cos_sin(
                self.inv_freq, self.max_seq_len_cached, q.dtype, q.device
            )
        cos = self.cos_cached[position_ids].unsqueeze(1)
        sin = self.sin_cached[position_ids].unsqueeze(1)

        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
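
A possible usage sketch for the class above, with query/key tensors shaped [batch, num_heads, seq_len, head_size] and position ids derived from the attention mask as discussed earlier (all shapes and values are illustrative):

import torch

batch, n_heads, seq_len, head_size = 2, 8, 16, 64
q = torch.randn(batch, n_heads, seq_len, head_size)
k = torch.randn(batch, n_heads, seq_len, head_size)

# Left-padded batch: the first sequence has 4 padding tokens in front.
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask[0, :4] = 0
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

rotary = RotaryEmbedding(dim=head_size, max_position_embeddings=2048)
q_rot, k_rot = rotary(q, k, position_ids, seq_len=seq_len)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 16, 64]) twice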

OlivierDehaene · Mar 15 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Apr 13 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · May 09 '23