GPT NeoX rotary embedding does not work with left padding
System Info
- `transformers` version: 4.26.1
- Platform: Linux-5.4.0-1097-aws-x86_64-with-glibc2.27
- Python version: 3.10.9
- Huggingface_hub version: 0.12.1
- PyTorch version (GPU?): 1.13.1+cu117 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
@ArthurZucker, @younesbelkada, @gante
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto")
f_not_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=False, return_tensors="pt"))
f_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=True, pad_to_multiple_of=256, return_tensors="pt"))
torch.testing.assert_allclose(f_not_padded.logits[:, -1], f_padded.logits[:, -1])
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 6057 / 50288 (12.0%)
# Greatest absolute difference: 0.0003177821636199951 at index (0, 4649) (up to 1e-05 allowed)
# Greatest relative difference: 1.5682868874196898 at index (0, 30410) (up to 0.0001 allowed)
The problem is exacerbated in bfloat16:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto", torch_dtype=torch.bfloat16)
f_not_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=False, return_tensors="pt"))
f_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=True, pad_to_multiple_of=256, return_tensors="pt"))
torch.testing.assert_allclose(f_not_padded.logits[:, -1], f_padded.logits[:, -1])
# AssertionError: Tensor-likes are not equal!
# Mismatched elements: 49417 / 50288 (98.3%)
# Greatest absolute difference: 1.154541015625 at index (0, 50271)
# Greatest relative difference: 2058.906976744186 at index (0, 29917)
Expected behavior
Left padding should have no influence on the resulting logits.
While the differences do not look like much, they have a huge impact on generation.
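For reference, the rotary embedding should see positions that count only the real tokens, so for a left-padded batch the position ids would be derived from the attention mask rather than from the padded width. Below is a minimal sketch of that derivation using the usual cumsum-over-the-attention-mask pattern; the tensors are illustrative and not tied to this model's API at this version:

import torch

# Left-padded batch: row 0 has 2 pad tokens followed by 3 real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# Real tokens get positions 0, 1, 2, ... regardless of how much padding precedes them.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)  # dummy id for pad slots
# position_ids -> tensor([[0, 0, 0, 1, 2],
#                         [0, 1, 2, 3, 4]])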
Hey, thanks for reporting!
It is possible that this is not the root cause, but there is an issue with these lines:
offset = 0
if has_layer_past:
    offset = layer_past[0].shape[-2]
    seq_len += offset
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
`offset` and `seq_len` are not computed correctly when you have padding.
On a side note, it is impossible to have a single value for `offset`, as different sequences in the batch might have different lengths and therefore different offsets when padding left.
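A minimal illustration of that point (the numbers are made up): two left-padded sequences whose key/value caches were padded to the same length, but whose true offsets differ per row.

import torch

past_length = 256                       # layer_past[0].shape[-2] is 256 for both rows
attention_mask = torch.zeros(2, past_length, dtype=torch.long)
attention_mask[0, -40:] = 1             # row 0 really contains 40 tokens
attention_mask[1, -200:] = 1            # row 1 really contains 200 tokens

# The snippet above uses the single scalar offset = 256 for the whole batch,
# while the next real token should get rotary position 40 for row 0 and 200 for row 1.
true_offsets = attention_mask.sum(dim=-1)  # tensor([ 40, 200])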
We use left padding extensively on the serving side, as we have a dynamic batching logic that batches sequences of very different lengths together.
While the pad-to-256 example above seems extreme in isolation, it is completely normal when serving. We sometimes go even higher in chat applications, where one member of the batch has a very long history (> 1000 tokens) while other sequences have only just started (~40 tokens).
We also serve all the models in bfloat16 when available, and we almost always use sampling, which amplifies the logits issue even more.
Hey everyone! Yes, that's correct; it is pretty much the same issue as the one I reported here -- we should be passing position_ids all the way down to the attention layer and compute the sequence length from them.
We have an open PR to fix the same issue for GPT-J (#22069); I'll make sure it is ported to GPT NeoX when it is merged. We are currently ironing out torch.fx issues (adding the correct behavior makes the tensors dynamic, which blocks existing features).
Hi @OlivierDehaene, I'm actually in the middle of porting the fix from #22069 to GPT-Neox too, since I was also interested in that one (in parallel with other things including resolving this torch.fx issue).
Also for reference there's a similar existing issue which went stale: https://github.com/huggingface/transformers/issues/18999
Hi @njhill! Nice, thanks for working on this! For now I have a fix in my text-generation-inference fork, as we have multiple NeoX models in prod and I need a fix ASAP. It's essentially the same as yours, I think.
import torch


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self.cos_cached = None
        self.sin_cached = None

    @staticmethod
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from the paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)

    def forward(self, q, k, position_ids, seq_len=None):
        # q, k: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
            if seq_len > self.max_seq_len_cached:
                self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._create_cos_sin(
                self.inv_freq, self.max_seq_len_cached, q.dtype, q.device
            )
        # Index the cache with per-token position ids instead of slicing by a scalar offset,
        # so left-padded sequences get the positions of their real tokens.
        cos = self.cos_cached[position_ids].unsqueeze(1)
        sin = self.sin_cached[position_ids].unsqueeze(1)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
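A hypothetical usage sketch of the RotaryEmbedding class above (the tensor sizes and position ids are illustrative): pad slots get a dummy position id, which is harmless because attention to those keys is masked out anyway, while real tokens get positions 0..n-1 regardless of how much left padding precedes them.

import torch

position_ids = torch.tensor([[0, 0, 0, 1, 2],    # 2 pad tokens, then 3 real tokens
                             [0, 1, 2, 3, 4]])   # no padding
rotary = RotaryEmbedding(dim=64, max_position_embeddings=2048)
q = torch.randn(2, 8, 5, 64)  # [batch, num_heads, seq_len, head_size]
k = torch.randn(2, 8, 5, 64)
q_rot, k_rot = rotary(q, k, position_ids, seq_len=5)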
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.