Llama inference instability in fp16 producing inf in the middle of the model
System Info
- transformers version: 4.35.0.dev0
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.9.16
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.3.1
- Accelerate version: 0.25.0.dev0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.0+cu118 (True)
- Using GPU in script?: A100
Who can help?
@ydshieh @fxmarty @gante
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Hi, I am encountering inference instability with Llama running in fp16 when left padding is used, especially when full rows are masked out in the 4D attention mask.
At some point in the forward pass, inf values may appear in intermediate activations, ultimately leading to tensors filled with nan and raising the error:
Traceback (most recent call last):
File "=debug.py", line 38, in <module>
outputs = model.generate(
File "/fsx/felix/condaenvs/fx/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 1704, in generate
return self.sample(
File "/fsx/felix/transformers/src/transformers/generation/utils.py", line 2822, in sample
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
Note that the inf values appear specifically at a padding position.
Reproduction:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
token = "[specify your token]"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, token=token)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, token=token)

sentence = "Felix Marty is a French"
# Alternatively, the issue can be reproduced with:
# sentence = "Elon Musk is a South"
# max_length=9

inp = tokenizer(sentence, return_tensors="pt", padding="max_length", max_length=9).to("cuda")
print("inp", inp["input_ids"].shape)
print("inp", inp)

torch.set_printoptions(threshold=10000000)

print("\n\n*** Generate:")
with torch.no_grad():
    outputs = model.generate(
        **inp,
        max_new_tokens=10,
        do_sample=True,
        top_p=0.9,
        temperature=float(0.01),
        top_k=40,
    )
print(tokenizer.batch_decode(outputs))
Printing torch.all(torch.isfinite(...)) at some points in the model, it appears that the inf values start to appear in the MLP at self.act_fn(self.gate_proj(x)) * self.up_proj(x), and things go crazy from there.
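For anyone who wants to localize this themselves, here is a minimal sketch of the kind of instrumentation I mean (my own helper, not part of transformers; the name register_finite_checks is made up): it registers forward hooks on every module and prints each module whose output contains non-finite values, so the first printed line shows where the inf first appears.

import torch

def register_finite_checks(model):
    # Attach forward hooks that report modules producing non-finite outputs.
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out) and not torch.isfinite(out).all():
                print(f"non-finite output in: {name}")
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles

# Usage: handles = register_finite_checks(model), then run model.generate(...) as above,
# and finally remove the hooks with: for h in handles: h.remove()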
What's interesting is that, for example, fixing the attention mask so that the rows of the two left padding tokens are no longer fully masked solves the issue.
It makes me think that the solution implemented for SDPA to avoid fully masked rows in the attention mask (https://github.com/huggingface/transformers/pull/26572) may actually be required in other cases such as this one - but it is unclear why it relates to overflow here.
WDYT @gante @ydshieh? Is this something you have ever observed?
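For context on why fully masked rows are special, a small standalone illustration (my own, using plain softmax rather than the actual attention code): with -inf masking a fully masked row turns into nan, while with the finfo.min additive masking used in transformers it degenerates into a uniform distribution, i.e. the padding query attends to every position.

import torch

# A row where every key is masked with -inf: softmax returns nan.
scores = torch.full((1, 4), float("-inf"))
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]])

# Same row masked with the fp16 minimum (the additive mask value used in transformers):
# softmax degenerates to a uniform distribution instead.
scores_fp16 = torch.full((1, 4), torch.finfo(torch.float16).min, dtype=torch.float16)
print(torch.softmax(scores_fp16, dim=-1))  # ~tensor([[0.25, 0.25, 0.25, 0.25]])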
Expected behavior
No inf appearing in the middle of inference with the fp16 model.
Related to #17937, but that one uses a dummy model.
Will take a look here.
@ArthurZucker has been tracking it, and has a draft PR for it: https://github.com/huggingface/transformers/pull/27114
@fxmarty Can you check if applying this change fixes it?
@ydshieh @gante Thank you! No, this PR is unrelated unfortunately, as the issue also happens with the prompt "Elon Musk is a South" and max_length=9 (only one padding token), where the extended attention mask does not have any inf.
It may just be instability in the model, but it feels weird that it arises only when some attention mask rows are fully masked.
Well, I guess it needs another deep dive 😬
I haven't been able to reach a final conclusion, but if we change the else block in LlamaMLP.forward to
h1 = self.gate_proj(x)        # gate projection
h2 = self.act_fn(h1)          # activation applied to the gate
h3 = self.up_proj(x)          # up projection
h4 = self.down_proj(h2 * h3)  # down projection of the elementwise product
down_proj = h4
and print their maximal absolute values, we see their magnitudes become unusually large starting from layer 29 (0-based), get amplified to about 255 in layer 30, and then h4 becomes inf.
The question is what happens in layer 29 for this input, but I am afraid it is just a numerical issue over which we don't really have control.
Will take a further look later when I get spare time.
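For reference, the per-layer numbers below could be collected with something along these lines instead of editing the modeling code (a sketch with a made-up helper name, log_mlp_absmax; it recomputes h1-h4 from the MLP submodules inside a forward hook, so it roughly doubles the MLP compute):

import torch

def log_mlp_absmax(model):
    # Hook every LlamaMLP and print the max absolute value of h1..h4 as defined above.
    handles = []

    def make_hook(layer_idx):
        def hook(mlp, inputs, output):
            x = inputs[0]
            h1 = mlp.gate_proj(x)
            h2 = mlp.act_fn(h1)
            h3 = mlp.up_proj(x)
            h4 = mlp.down_proj(h2 * h3)
            print(f"layer: {layer_idx}")
            for name, t in (("h1", h1), ("h2", h2), ("h3", h3), ("h4", h4)):
                print(f"{name}: {t.abs().max().item()}")
        return hook

    for idx, layer in enumerate(model.model.layers):
        handles.append(layer.mlp.register_forward_hook(make_hook(idx)))
    return handles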
------------------------------------
layer: 29
h1: 13.234375
h2: 4.5
h3: 18.15625
h4: 57.28125
------------------------------------
layer: 30
h1: 255.875
h2: 255.875
h3: 261.75
h4: inf
------------------------------------
layer: 31
h1: nan
h2: nan
h3: nan
h4: nan
------------------------------------
Full dump (all layers):
layer: 0
h1: 4.2109375
h2: 4.1484375
h3: 2.220703125
h4: 4.08203125
------------------------------------
layer: 1
h1: 19.75
h2: 19.75
h3: 18.328125
h4: 753.0
------------------------------------
layer: 2
h1: 2.8671875
h2: 2.025390625
h3: 2.2421875
h4: 4.1640625
------------------------------------
layer: 3
h1: 2.259765625
h2: 1.11328125
h3: 1.484375
h4: 0.423583984375
------------------------------------
layer: 4
h1: 4.4375
h2: 4.38671875
h3: 2.642578125
h4: 7.58203125
------------------------------------
layer: 5
h1: 3.142578125
h2: 2.416015625
h3: 2.431640625
h4: 1.2734375
------------------------------------
layer: 6
h1: 2.7265625
h2: 1.98828125
h3: 2.2578125
h4: 1.0556640625
------------------------------------
layer: 7
h1: 2.650390625
h2: 1.8046875
h3: 2.349609375
h4: 1.3251953125
------------------------------------
layer: 8
h1: 3.34375
h2: 1.76171875
h3: 2.6796875
h4: 1.8408203125
------------------------------------
layer: 9
h1: 4.37109375
h2: 2.328125
h3: 3.142578125
h4: 1.734375
------------------------------------
layer: 10
h1: 4.3046875
h2: 3.1796875
h3: 2.62109375
h4: 1.3212890625
------------------------------------
layer: 11
h1: 3.853515625
h2: 3.5078125
h3: 3.0078125
h4: 1.62890625
------------------------------------
layer: 12
h1: 3.33203125
h2: 2.224609375
h3: 2.548828125
h4: 1.005859375
------------------------------------
layer: 13
h1: 3.560546875
h2: 2.783203125
h3: 3.087890625
h4: 2.23828125
------------------------------------
layer: 14
h1: 3.841796875
h2: 2.9609375
h3: 2.63671875
h4: 0.67626953125
------------------------------------
layer: 15
h1: 3.609375
h2: 3.4765625
h3: 3.107421875
h4: 1.8994140625
------------------------------------
layer: 16
h1: 5.06640625
h2: 4.078125
h3: 4.28515625
h4: 5.421875
------------------------------------
layer: 17
h1: 5.35546875
h2: 5.33203125
h3: 3.740234375
h4: 2.919921875
------------------------------------
layer: 18
h1: 4.0
h2: 3.853515625
h3: 3.60546875
h4: 3.271484375
------------------------------------
layer: 19
h1: 4.46484375
h2: 4.4140625
h3: 3.75
h4: 4.16796875
------------------------------------
layer: 20
h1: 3.66796875
h2: 2.970703125
h3: 3.658203125
h4: 2.962890625
------------------------------------
layer: 21
h1: 5.34375
h2: 5.31640625
h3: 3.400390625
h4: 2.0234375
------------------------------------
layer: 22
h1: 3.318359375
h2: 3.203125
h3: 3.451171875
h4: 1.546875
------------------------------------
layer: 23
h1: 4.28125
h2: 4.22265625
h3: 4.109375
h4: 2.67578125
------------------------------------
layer: 24
h1: 4.21484375
h2: 3.220703125
h3: 3.15625
h4: 0.9482421875
------------------------------------
layer: 25
h1: 3.93359375
h2: 3.7109375
h3: 3.947265625
h4: 3.9921875
------------------------------------
layer: 26
h1: 4.3359375
h2: 3.865234375
h3: 4.37109375
h4: 1.7041015625
------------------------------------
layer: 27
h1: 4.55078125
h2: 3.400390625
h3: 3.630859375
h4: 1.4111328125
------------------------------------
layer: 28
h1: 4.90234375
h2: 4.4453125
h3: 7.54296875
h4: 2.0546875
------------------------------------
layer: 29
h1: 13.234375
h2: 4.5
h3: 18.15625
h4: 57.28125
------------------------------------
layer: 30
h1: 255.875
h2: 255.875
h3: 261.75
h4: inf
------------------------------------
layer: 31
h1: nan
h2: nan
h3: nan
h4: nan
------------------------------------
After taking a further look, this doesn't seem to be related to any bug, but rather to the limitations of fp16, and it also depends on the input data.
One observation: larger tensor values tend to appear when the prompt is (very) short.
Also, when this happens, I often see that many entries in the corresponding multiplications share the same sign.
Nothing more I can provide, I am afraid.
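To make the same-sign observation concrete (my own toy example, not taken from the model): fp16 tops out at 65504, so a reduction over many moderately large same-sign products, like the 11008-wide down_proj contraction, overflows to inf even though every individual entry is finite.

import torch

print(torch.finfo(torch.float16).max)  # 65504.0

# 11008 same-sign products of magnitude 64 sum to ~704512, far past the fp16 range.
v = torch.full((11008,), 8.0, dtype=torch.float16)
print((v * v).sum(dtype=torch.float16))  # tensor(inf, dtype=torch.float16)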
Thanks a lot @ydshieh. Did you notice any difference with whether rows are fully masked in the attention mask or not?
We can probably close this one - at least it is good to know that (at least) llama 7b has numerical instabilities during inference in fp16.
whether rows are fully masked in the attention mask or not?
Oh, I might have made a mistake! You have max_length=9 in the code snippet, so when I used a long sequence there was no padding!
OK, I need to recheck!
I think beam search with RoPE and fp16 has instabilities, yes, reported here: #26332. If I am not mistaken, this is what we have here, no? And I think a recent PR to fix this was merged: #26843. But yeah, I have a pretty huge list of bugs to process!
FYI: here the issue is not even in the generation - the issue comes already in the first step: just encoding the input prompt.
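A quick way to see this (reusing model and inp from the reproduction above): a single forward pass over the padded prompt already produces non-finite logits, with no sampling or generation loop involved.

import torch

# `model` and `inp` are the ones built in the reproduction snippet above.
with torch.no_grad():
    logits = model(**inp).logits
print(torch.isfinite(logits).all())  # tensor(False, device='cuda:0') when the issue triggers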
Same issue in layer 29/30 in https://github.com/PanQiWei/AutoGPTQ/issues/412. Unmasking fully masked padding rows solves the issue there as well.
And the nans indeed start to appear at the padding index if we do not unmask:
In layer 30, without unmasking:
hidden_states after layernorm torch.Size([2, 6, 4096])
hidden_states b=0, seq_idx=0 mean: 0.00121307373046875
hidden_states b=0, seq_idx=1 mean: -0.0168914794921875
hidden_states b=0, seq_idx=2 mean: -0.00237274169921875
hidden_states b=0, seq_idx=3 mean: 0.0007181167602539062
hidden_states b=0, seq_idx=4 mean: -0.0108642578125
hidden_states b=0, seq_idx=5 mean: -0.006961822509765625
hidden_states b=1, seq_idx=0 mean: -0.0016736984252929688
hidden_states b=1, seq_idx=1 mean: 0.0012159347534179688
hidden_states b=1, seq_idx=2 mean: -0.016876220703125
hidden_states b=1, seq_idx=3 mean: -0.0023746490478515625
hidden_states b=1, seq_idx=4 mean: 0.0006799697875976562
hidden_states b=1, seq_idx=5 mean: -0.010833740234375
up_proj, down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(1.0762e+02, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.6924e-01, device='cuda:0', dtype=torch.float16)
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(1.0962e+02, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.5728e-01, device='cuda:0', dtype=torch.float16)
gate_proj b=0, seq_idx=0 mean: -0.047821, absmax: 14.078125
gate_proj b=0, seq_idx=1 mean: -0.208618, absmax: 23.078125
gate_proj b=0, seq_idx=2 mean: -0.253174, absmax: 23.859375
gate_proj b=0, seq_idx=3 mean: -0.270264, absmax: 27.84375
gate_proj b=0, seq_idx=4 mean: -0.184692, absmax: 14.5078125
gate_proj b=0, seq_idx=5 mean: -0.254639, absmax: 12.8203125
gate_proj b=1, seq_idx=0 mean: 0.309814, absmax: 107.625
gate_proj b=1, seq_idx=1 mean: -0.047852, absmax: 14.078125
gate_proj b=1, seq_idx=2 mean: -0.208496, absmax: 23.234375
gate_proj b=1, seq_idx=3 mean: -0.252930, absmax: 23.96875
gate_proj b=1, seq_idx=4 mean: -0.270508, absmax: 27.984375
gate_proj b=1, seq_idx=5 mean: -0.184937, absmax: 14.6484375
up_proj b=0, seq_idx=0 mean: 0.001290, absmax: 15.0546875
up_proj b=0, seq_idx=1 mean: -0.008339, absmax: 18.40625
up_proj b=0, seq_idx=2 mean: -0.016205, absmax: 18.0
up_proj b=0, seq_idx=3 mean: -0.005768, absmax: 23.234375
up_proj b=0, seq_idx=4 mean: -0.000823, absmax: 6.44921875
up_proj b=0, seq_idx=5 mean: -0.003519, absmax: 11.6171875
up_proj b=1, seq_idx=0 mean: 0.015915, absmax: 109.625
up_proj b=1, seq_idx=1 mean: 0.001284, absmax: 15.046875
up_proj b=1, seq_idx=2 mean: -0.008362, absmax: 18.5625
up_proj b=1, seq_idx=3 mean: -0.016220, absmax: 18.046875
up_proj b=1, seq_idx=4 mean: -0.005787, absmax: 23.34375
up_proj b=1, seq_idx=5 mean: -0.000838, absmax: 6.546875
act_gate b=0, seq_idx=0 mean: -0.011940, absmax: 14.078125
act_gate b=0, seq_idx=1 mean: 0.004330, absmax: 4.80859375
act_gate b=0, seq_idx=2 mean: 0.010277, absmax: 5.859375
act_gate b=0, seq_idx=3 mean: -0.015503, absmax: 6.46875
act_gate b=0, seq_idx=4 mean: 0.031921, absmax: 5.67578125
act_gate b=0, seq_idx=5 mean: -0.006973, absmax: 6.5
act_gate b=1, seq_idx=0 mean: 0.219971, absmax: 107.625
act_gate b=1, seq_idx=1 mean: -0.011948, absmax: 14.078125
act_gate b=1, seq_idx=2 mean: 0.004345, absmax: 4.80859375
act_gate b=1, seq_idx=3 mean: 0.010429, absmax: 5.859375
act_gate b=1, seq_idx=4 mean: -0.015495, absmax: 6.46484375
act_gate b=1, seq_idx=5 mean: 0.031738, absmax: 5.67578125
inter b=0, seq_idx=0 mean: 0.03338623046875, absmax: 212.0
inter b=0, seq_idx=1 mean: 0.00040793418884277344, absmax: 6.7734375
inter b=0, seq_idx=2 mean: 0.0011510848999023438, absmax: 7.125
inter b=0, seq_idx=3 mean: 0.00832366943359375, absmax: 17.46875
inter b=0, seq_idx=4 mean: 0.00707244873046875, absmax: 13.90625
inter b=0, seq_idx=5 mean: 0.0014142990112304688, absmax: 7.62890625
inter b=1, seq_idx=0 mean: 1.3212890625, absmax: 11800.0
inter b=1, seq_idx=1 mean: 0.03338623046875, absmax: 211.875
inter b=1, seq_idx=2 mean: 0.0004088878631591797, absmax: 6.796875
inter b=1, seq_idx=3 mean: 0.0011835098266601562, absmax: 7.1484375
inter b=1, seq_idx=4 mean: 0.008331298828125, absmax: 17.515625
inter b=1, seq_idx=5 mean: 0.007049560546875, absmax: 13.8828125
call down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 4096])
output finite tensor(False, device='cuda:0')
output absmax tensor(inf, device='cuda:0', dtype=torch.float16)
output absmean tensor(inf, device='cuda:0', dtype=torch.float16)
down_proj b=0, seq_idx=0 finite: True
down_proj b=0, seq_idx=1 finite: True
down_proj b=0, seq_idx=2 finite: True
down_proj b=0, seq_idx=3 finite: True
down_proj b=0, seq_idx=4 finite: True
down_proj b=0, seq_idx=5 finite: True
down_proj b=1, seq_idx=0 finite: False
down_proj b=1, seq_idx=1 finite: True
down_proj b=1, seq_idx=2 finite: True
down_proj b=1, seq_idx=3 finite: True
down_proj b=1, seq_idx=4 finite: True
down_proj b=1, seq_idx=5 finite: True
In layer 30, with unmasking of fully masked rows:
hidden_states after layernorm torch.Size([2, 6, 4096])
hidden_states b=0, seq_idx=0 mean: 0.0012102127075195312
hidden_states b=0, seq_idx=1 mean: -0.01690673828125
hidden_states b=0, seq_idx=2 mean: -0.002384185791015625
hidden_states b=0, seq_idx=3 mean: 0.0007028579711914062
hidden_states b=0, seq_idx=4 mean: -0.01085662841796875
hidden_states b=0, seq_idx=5 mean: -0.006946563720703125
hidden_states b=1, seq_idx=0 mean: -0.0006947517395019531
hidden_states b=1, seq_idx=1 mean: 0.00121307373046875
hidden_states b=1, seq_idx=2 mean: -0.0168609619140625
hidden_states b=1, seq_idx=3 mean: -0.0023975372314453125
hidden_states b=1, seq_idx=4 mean: 0.0006928443908691406
hidden_states b=1, seq_idx=5 mean: -0.01084136962890625
up_proj, down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(3.3969e+01, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.5752e-01, device='cuda:0', dtype=torch.float16)
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 11008])
output finite tensor(True, device='cuda:0')
output absmax tensor(3.1141e+01, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.5410e-01, device='cuda:0', dtype=torch.float16)
gate_proj b=0, seq_idx=0 mean: -0.047882, absmax: 14.078125
gate_proj b=0, seq_idx=1 mean: -0.208374, absmax: 23.09375
gate_proj b=0, seq_idx=2 mean: -0.252930, absmax: 23.875
gate_proj b=0, seq_idx=3 mean: -0.270508, absmax: 27.90625
gate_proj b=0, seq_idx=4 mean: -0.184692, absmax: 14.515625
gate_proj b=0, seq_idx=5 mean: -0.254639, absmax: 12.84375
gate_proj b=1, seq_idx=0 mean: -0.073853, absmax: 33.96875
gate_proj b=1, seq_idx=1 mean: -0.047852, absmax: 14.1015625
gate_proj b=1, seq_idx=2 mean: -0.208496, absmax: 23.21875
gate_proj b=1, seq_idx=3 mean: -0.253418, absmax: 23.953125
gate_proj b=1, seq_idx=4 mean: -0.270264, absmax: 27.984375
gate_proj b=1, seq_idx=5 mean: -0.184692, absmax: 14.5546875
up_proj b=0, seq_idx=0 mean: 0.001290, absmax: 15.046875
up_proj b=0, seq_idx=1 mean: -0.008347, absmax: 18.40625
up_proj b=0, seq_idx=2 mean: -0.016235, absmax: 17.984375
up_proj b=0, seq_idx=3 mean: -0.005745, absmax: 23.265625
up_proj b=0, seq_idx=4 mean: -0.000815, absmax: 6.4453125
up_proj b=0, seq_idx=5 mean: -0.003561, absmax: 11.6328125
up_proj b=1, seq_idx=0 mean: -0.004223, absmax: 31.140625
up_proj b=1, seq_idx=1 mean: 0.001290, absmax: 15.078125
up_proj b=1, seq_idx=2 mean: -0.008362, absmax: 18.5625
up_proj b=1, seq_idx=3 mean: -0.016251, absmax: 18.03125
up_proj b=1, seq_idx=4 mean: -0.005783, absmax: 23.328125
up_proj b=1, seq_idx=5 mean: -0.000843, absmax: 6.4765625
act_gate b=0, seq_idx=0 mean: -0.011971, absmax: 14.078125
act_gate b=0, seq_idx=1 mean: 0.004372, absmax: 4.8046875
act_gate b=0, seq_idx=2 mean: 0.010483, absmax: 5.86328125
act_gate b=0, seq_idx=3 mean: -0.015427, absmax: 6.46875
act_gate b=0, seq_idx=4 mean: 0.031860, absmax: 5.67578125
act_gate b=0, seq_idx=5 mean: -0.007015, absmax: 6.4921875
act_gate b=1, seq_idx=0 mean: 0.002026, absmax: 4.19140625
act_gate b=1, seq_idx=1 mean: -0.011955, absmax: 14.1015625
act_gate b=1, seq_idx=2 mean: 0.004314, absmax: 4.8125
act_gate b=1, seq_idx=3 mean: 0.010254, absmax: 5.86328125
act_gate b=1, seq_idx=4 mean: -0.015503, absmax: 6.4609375
act_gate b=1, seq_idx=5 mean: 0.031891, absmax: 5.6640625
inter b=0, seq_idx=0 mean: 0.033355712890625, absmax: 211.875
inter b=0, seq_idx=1 mean: 0.00041985511779785156, absmax: 6.76953125
inter b=0, seq_idx=2 mean: 0.0011568069458007812, absmax: 7.1328125
inter b=0, seq_idx=3 mean: 0.008331298828125, absmax: 17.421875
inter b=0, seq_idx=4 mean: 0.007068634033203125, absmax: 13.8828125
inter b=0, seq_idx=5 mean: 0.0014171600341796875, absmax: 7.63671875
inter b=1, seq_idx=0 mean: 0.0037746429443359375, absmax: 21.890625
inter b=1, seq_idx=1 mean: 0.033477783203125, absmax: 212.625
inter b=1, seq_idx=2 mean: 0.00041794776916503906, absmax: 6.78125
inter b=1, seq_idx=3 mean: 0.001155853271484375, absmax: 7.1328125
inter b=1, seq_idx=4 mean: 0.00830078125, absmax: 17.4375
inter b=1, seq_idx=5 mean: 0.007068634033203125, absmax: 13.828125
call down_proj
--- forward
input finite tensor(True, device='cuda:0')
output torch.Size([2, 6, 4096])
output finite tensor(True, device='cuda:0')
output absmax tensor(5.3750e+02, device='cuda:0', dtype=torch.float16)
output absmean tensor(4.9854e-01, device='cuda:0', dtype=torch.float16)
down_proj b=0, seq_idx=0 finite: True
down_proj b=0, seq_idx=1 finite: True
down_proj b=0, seq_idx=2 finite: True
down_proj b=0, seq_idx=3 finite: True
down_proj b=0, seq_idx=4 finite: True
down_proj b=0, seq_idx=5 finite: True
down_proj b=1, seq_idx=0 finite: True
down_proj b=1, seq_idx=1 finite: True
down_proj b=1, seq_idx=2 finite: True
down_proj b=1, seq_idx=3 finite: True
down_proj b=1, seq_idx=4 finite: True
down_proj b=1, seq_idx=5 finite: True
It is unclear to me what is happening here and how it relates to fully masked rows.
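For clarity, by "unmasking" I mean something along these lines (a sketch with a made-up helper name, unmask_fully_masked_rows, applied to the 4D additive mask of shape (batch, 1, query_len, kv_len); the actual change in the linked PR may differ):

import torch

def unmask_fully_masked_rows(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: 4D additive mask with 0 for attended positions and
    # torch.finfo(dtype).min for masked ones.
    min_value = torch.finfo(attention_mask.dtype).min
    # Rows where every key position is masked (these correspond to padding queries).
    fully_masked = (attention_mask == min_value).all(dim=-1, keepdim=True)
    # Reset those rows to 0 so the padding query attends normally instead of
    # producing a degenerate fully-masked row.
    return attention_mask.masked_fill(fully_masked, 0.0)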
Great details! I am wondering if maybe the original training saw the unmasked rows, but at inference time the model sees another version, which leads to these large values now (similar to the different behavior of SDPA between torch 2.0.1 / 2.1.0 on GPU that we saw previously).
@ydshieh At some point I want to give the original llama repo a try, to see how padding is handled there.
not stale
mark
I think computing RoPE in float32 precision should partly fix this.
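Roughly, the idea is to build the rotary cos/sin tables in float32 and only cast back to the activation dtype at the end (a sketch of the idea, not the exact transformers implementation):

import torch

def rope_cos_sin(seq_len, dim, base=10000.0, dtype=torch.float16, device="cpu"):
    # Everything below stays in float32 until the final cast.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)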
I'll mark this as closed, because Llama now computes RoPE in float32! 🥳 Feel free to ping me if you feel this should not be closed.