Different batch sizes lead to different inference results
Hi,
I found that when setting load_in_8bit=True, different batch sizes lead to very different results, even when doing inference only. I have seen this with several HF pretrained language models in int8.
A simple example follows; I get very different results when comparing out1 and out2.
Thank you!
GPU: 1 RTX3090, Driver version: 470.103.01
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 114
from transformers import GPT2Tokenizer, AutoModelForCausalLM
import torch
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b",
                                             device_map='auto', load_in_8bit=True)
#model.cuda()
model.eval()
@torch.no_grad()
def do_inference(model, input_ids, attention_mask):
    outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask.cuda())
    return outputs.logits.cpu()
batch_sents = [
'Review: luminous interviews and amazingly evocative film from three decades ago \nSentiment:',
'Review: with fewer gags to break the tedium \nSentiment:',
'Review: aims for poetry and ends up sounding like satire \nSentiment:',
'Review: no way original \nSentiment:'
]
enc_inputs = tokenizer(batch_sents, return_tensors='pt', padding=True)
# run inference with batch_size = 2
out1 = []
for i in range(0, len(batch_sents), 2):
    out = do_inference(model, enc_inputs['input_ids'][i:i+2], enc_inputs['attention_mask'][i:i+2])
    out1.append(out)
out1 = torch.cat(out1)
# run inference with batch_size = 4
out2 = do_inference(model, enc_inputs['input_ids'], enc_inputs['attention_mask'])
print(torch.abs(out1-out2).max()) #got tensor(2.0664, dtype=torch.float16) on my machine
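Because `padding=True` fills shorter rows with pad tokens, `torch.abs(out1-out2).max()` also covers logits at padded positions, which are not meaningful predictions. A minimal sketch of a comparison restricted to real tokens, assuming logits of shape `[batch, seq_len, vocab]` and the usual 0/1 `attention_mask`; `masked_max_abs_diff` is a hypothetical helper, not part of transformers:

```python
import torch


def masked_max_abs_diff(a: torch.Tensor, b: torch.Tensor,
                        attention_mask: torch.Tensor) -> torch.Tensor:
    """Largest |a - b| over real-token positions only (hypothetical helper).

    a, b: logits of shape [batch, seq_len, vocab]
    attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding
    """
    keep = attention_mask.bool()
    # Upcast so the subtraction itself adds no fp16 rounding.
    diff = (a.float() - b.float()).abs()
    # Boolean indexing over the first two dims -> [num_real_tokens, vocab].
    return diff[keep].max()
```

With the tensors above, `masked_max_abs_diff(out1, out2, enc_inputs['attention_mask'])` would report the largest discrepancy among real tokens only.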
A separate thing, probably not an issue but a pitfall: if I uncomment model.cuda(), all the outputs.logits become nan. Using model.to(torch.device("cuda")) or model.to("cuda") instead does not produce nan.
Thank you for the detailed issue. This was easy to replicate on my machine.
This bug does not occur with int8_threshold=0.0, which suggests it is likely an issue with outdated scaling constants used in the fp16 decomposition. I will look at this further in the coming days.
Excellent catch, thank you for reporting this!
Hi Tim,
Thank you so much for your reply.
I'd like to share my new findings. It seems the results depend on which other instances are in the same batch:
# To avoid any potential issues in attention_mask, truncate all sents to the same length
enc_inputs = tokenizer(batch_sents, return_tensors='pt', max_length=5, truncation=True)
# run inference with batch_size = 2
out1 = []
for i in range(0, len(batch_sents), 2):
    out = do_inference(model, enc_inputs['input_ids'][i:i+2], enc_inputs['attention_mask'][i:i+2])
    out1.append(out)
out1 = torch.cat(out1)
# run inference with batch_size = 2 in a shuffled order
shuffled_order = [0,3,1,2]
input_ids_shf, attn_mask_shf = enc_inputs['input_ids'][shuffled_order], enc_inputs['attention_mask'][shuffled_order]
out4 = []
for i in range(0, len(batch_sents), 2):
    out = do_inference(model, input_ids_shf[i:i+2], attn_mask_shf[i:i+2])
    out4.append(out)
out4 = torch.cat(out4)
# shuffle back
ret_order = torch.argsort(torch.tensor(shuffled_order))  # == [0,2,3,1]
out4 = out4[ret_order]
print(torch.abs(out1-out4).max())
# got 1.0781 when load_in_8bit=True with the default threshold
# got 0 when load_in_8bit=True and int8_threshold = 0
# got 0 when load_in_8bit=False
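The pattern above (a sample's logits change with its batch neighbours under the default threshold, but not with int8_threshold=0 or without 8-bit loading) is what one would expect if outlier dimensions are chosen over the whole batch, which is how I understand the LLM.int8() paper to describe them. The toy simulation below is not the bitsandbytes kernel and does not establish the actual cause. It assumes vector-wise absmax int8 quantization (one scale per row of `x`, one per column of `w`), treats a hidden dimension as an outlier if any row in the batch exceeds `threshold` there, and uses `threshold=None` to mean "no decomposition" in this sketch only. `int8_matmul` and `mixed_precision_matmul` are hypothetical names.

```python
import torch


def int8_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Simulated vector-wise int8 matmul: absmax scale per row of x and per column of w."""
    sx = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    sw = w.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0
    qx = torch.round(x / sx).clamp(-127, 127)  # integer values stored as float
    qw = torch.round(w / sw).clamp(-127, 127)
    return (qx @ qw) * sx * sw  # rescale the integer product back


def mixed_precision_matmul(x, w, threshold=None):
    """Toy outlier decomposition; threshold=None disables it in this sketch."""
    if threshold is None:
        return int8_matmul(x, w)
    # A hidden dimension is an outlier if ANY row in the batch exceeds the threshold.
    outlier = (x.abs() > threshold).any(dim=0)
    out = int8_matmul(x[:, ~outlier], w[~outlier, :])
    if outlier.any():
        # Outlier dimensions go through the full-precision path for every row.
        out = out + x[:, outlier] @ w[outlier, :]
    return out


w = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
x0 = torch.tensor([[1.0, 0.01]])   # an ordinary sample
x1 = torch.tensor([[0.0, 100.0]])  # a sample with a large value in dimension 1
batch = torch.cat([x0, x1])

for threshold in (None, 6.0):
    alone = mixed_precision_matmul(x0, w, threshold)
    batched = mixed_precision_matmul(batch, w, threshold)[:1]
    print(threshold, (alone - batched).abs().max().item())
```

With the decomposition off, `x0` gives exactly the same result alone or batched, because each row's scale depends only on that row. With it on, `x1`'s large value makes dimension 1 an outlier for the whole batch, so `x0`'s 0.01 (rounded to 1/127 ≈ 0.0079 on the int8 path) is now computed exactly and `x0`'s output shifts by roughly 0.002.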
A clarifying question:
When I set int8_threshold=0, is that equivalent to running the entire model in fp16 or in int8?
My understanding is that hidden-state values above this threshold are considered outliers and the operations on them are done in fp16, so operation-wise, int8_threshold=0 would be equivalent to running the entire model in fp16. Is that correct?
Thank you!
Today I also ran into this issue with a GPT-J model when doing greedy decoding with batch sizes of 8 vs. 16. I am glad to have confirmation that this is a known issue.
Thanks for the work done thus far!
I looked into it further, and even without int8, different batch sizes give different results.
Hey! Thanks for the heads-up! I think this is expected, as predictions of half-precision models (especially fp16 rather than bf16) can be quite unstable. Do you know if the lm_head is also in half precision? Casting the head to fp32 might help, but I'm not sure.
After seeing this happen without int8, I am not sure what the expected behavior is. When I was doing batch processing for GPT-J, I was using bfloat16, which is not as unstable as fp16 can be. I have not tried this with fp32, but bfloat16 should be a drop-in replacement.
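Even without int8, half-precision results can depend on the order in which partial sums are accumulated, and GPU matmul libraries may pick different kernels (and hence different reduction orders) for different batch shapes. A minimal CPU illustration: adding 1 twice to a large value can give a different answer than adding 2 once, because each intermediate sum is rounded to the format's precision. The values 2048 and 256 are chosen because the gap between neighbouring representable numbers there is 2 in fp16 and bf16, respectively; `two_orders` is a hypothetical helper name.

```python
import torch


def two_orders(big: float, dtype: torch.dtype) -> tuple[float, float]:
    """Return (big + 1) + 1 and big + (1 + 1), each computed in `dtype`."""
    b = torch.tensor(big, dtype=dtype)
    one = torch.tensor(1.0, dtype=dtype)
    # Left-to-right: each intermediate sum is rounded to `dtype`.
    left = (b + one) + one
    # Grouped: 1 + 1 = 2 is exact, and big + 2 is representable.
    right = b + (one + one)
    return left.item(), right.item()


for big, dtype in [(2048.0, torch.float16), (256.0, torch.bfloat16), (2048.0, torch.float32)]:
    print(dtype, two_orders(big, dtype))
# torch.float16 (2048.0, 2050.0)
# torch.bfloat16 (256.0, 258.0)
# torch.float32 (2050.0, 2050.0)
```

Both fp16 and bf16 lose the two `+1` steps here. bf16 has 8 significant bits versus fp16's 11, so it hits this at a smaller magnitude (256 vs. 2048); fp32 gives the same answer for both orders in this example.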
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.