Flan-T5-XXL: large difference between TRT-LLM and HF results
System Info
GPU: NVIDIA A10G, one g5.12xlarge instance (4 × A10G)
Who can help?
@byshiue @symphonylyh
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
- Start the Docker container and install dependencies as follows:
docker run --rm --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
# Install dependencies, TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev
pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com
pip uninstall -y mpmath
pip install mpmath==1.3.0
- Build engines from the public google/flan-t5-xxl checkpoint with the following commands:
python TensorRT-LLM/examples/enc_dec/t5/convert.py -i google/flan-t5-xxl -o /public_t5_trt_covert_official/ --weight_data_type float32 --inference_tensor_para_size 4
python TensorRT-LLM/examples/enc_dec/build.py --model_type t5 --world_size 4 --tp_size 4 --gpus_per_node 4 --weight_dir /public_t5_trt_covert_official/tp4 -o /public_t5_trt_engine_official --engine_name t5 --use_bert_attention_plugin --use_gpt_attention_plugin --use_gemm_plugin --dtype bfloat16 --max_batch_size 32 --max_encoder_input_len 128 --max_output_len 128 --parallel_build
- Call examples/enc_dec/run.py (modified) inside the container as follows:
mpirun --allow-run-as-root -np 4 python TensorRT-LLM/examples/enc_dec/run_modified.py --engine_dir /public_t5_trt_engine_official --engine_name t5 --model_name /fluency_model/ --max_new_token=128 --num_beams=1 --compare_hf_fp32
I slightly modified run.py to compare against HF bfloat16 results and to use a single example prompt. To replicate, replace the if __name__ == "__main__": block with the following:
if __name__ == "__main__":
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    logger.set_level(args.log_level)

    # FairSeq NMT test logic is different from HuggingFace models
    if 'wmt' in args.model_name:
        test_fairseq_models(args)
        exit()

    test_remove_padding = True
    if not test_remove_padding:
        if 't5' in args.model_name:
            input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
        elif 'bart' in args.model_name:
            input_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
        else:
            raise RuntimeError('Unsupported model type!')
    else:
        input_text = [
            "Keeping the Secret of Genetic Testing",
            # "translate English to German: The house is wonderful.",
            # "summarize: I am a high-performance inference optimizer and runtime.",
            # "During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world",
        ]

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)  # TODO: use model path instead
    tokenized_inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)

    max_new_tokens = args.max_new_tokens
    # by default int64, must cast to int32! otherwise C++ kernel will interpret as [a, 0, b, 0, c, 0, ...]
    input_ids = tokenized_inputs.input_ids.type(torch.IntTensor).to(
        'cuda')  # [batch_size, padded_length]

    if tensorrt_llm.mpi_rank() == 0:
        print("--------------------------------------")
        print(
            f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}"
        )
        print("input text: ", input_text)
        print("input ids: ", input_ids)
        print("input lengths: ", tokenized_inputs.attention_mask.sum(dim=1))
        print("--------------------------------------")

    model_config = AutoConfig.from_pretrained(args.model_name)

    # start_id for decoder (could add more input_ids as forced_decoder_ids)
    decoder_input_ids = torch.IntTensor([[model_config.decoder_start_token_id]
                                         ]).to('cuda')
    decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))

    if tensorrt_llm.mpi_rank() == 0:
        print("Starting comparison with HF bfloat16")

    # simple comparison with HF (bfloat16 here, despite the flag name)
    if args.compare_hf_fp32:
        if tensorrt_llm.mpi_rank() == 0:
            hf_model = AutoModelForSeq2SeqLM.from_pretrained(
                args.model_name,  # TODO: use model path instead
                device_map="balanced_low_0",
                torch_dtype=torch.bfloat16
                # torch_dtype=torch.float16 if '16' in dtype else torch.float32,  # TODO: use matched torch dtype
            ).eval()  # TODO: create config model path instead
            assert type(hf_model) in (
                T5ForConditionalGeneration, BartForConditionalGeneration,
                MBartForConditionalGeneration), 'Unsupported model!'

            tik = time.time()
            # breakpoint()
            hf_gen_output = hf_model.generate(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=args.num_beams,
                bos_token_id=tokenizer.bos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                # control logits processors
                no_repeat_ngram_size=0,  # disable no repeat post-processor
                forced_bos_token_id=None,  # disable forced first/last token
                forced_eos_token_id=None,
                min_length=0,
                # for debug
                output_scores=True,
                output_hidden_states=True,
                return_dict_in_generate=True)
            # get hf output scores
            hf_output_ids = hf_gen_output.sequences
            torch.cuda.synchronize()
            tok = time.time()

            output_ids = hf_output_ids.squeeze(dim=1)
            hf_output_text = tokenizer.batch_decode(output_ids,
                                                    skip_special_tokens=True)
            decoder_input_lengths = (decoder_input_ids !=
                                     tokenizer.pad_token_id).sum(dim=1)
            output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
                dim=1) - decoder_input_lengths
            print("--------------------------------------")
            print("HF output_ids: ", output_ids)
            print("HF output text: ", hf_output_text)
            print("HF output generated lengths: ", output_gen_lengths)
            print(f"HF E2E time {(tok-tik)*1000}ms")
            print("--------------------------------------")

            # Clean cache
            del hf_model
            gc.collect()
            torch.cuda.empty_cache()

    if tensorrt_llm.mpi_rank() == 0:
        print("Done with HF inference")
    # print(torch.cuda.memory_summary())

    # TRT-LLM runtime
    tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
                                               args.engine_dir,
                                               debug_mode=args.debug_mode)
    tik = time.time()
    tllm_output_ids = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug_mode=args.debug_mode,
        return_dict=False,  # when set return_dict=True, get outputs by key
        attention_mask=tokenized_inputs.attention_mask)
    tok = time.time()

    inference_dtype = tllm_model.encoder_model_config.dtype

    if tensorrt_llm.mpi_rank() == 0:
        output_ids = tllm_output_ids[:, 0, :]
        output_text = tokenizer.batch_decode(output_ids,
                                             skip_special_tokens=True)
        decoder_input_lengths = (decoder_input_ids !=
                                 tokenizer.pad_token_id).sum(dim=1)
        output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
            dim=1) - decoder_input_lengths
        print("--------------------------------------")
        print("TRT-LLM output_ids: ", output_ids)
        print("TRT-LLM output text: ", output_text)
        print("TRT-LLM output generated lengths: ", output_gen_lengths)
        print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
        print("Precision:", inference_dtype)
        print("--------------------------------------")

        # simple accuracy check
        if args.compare_hf_fp32:
            from difflib import SequenceMatcher
            match_rate = SequenceMatcher(None, "\n".join(output_text),
                                         "\n".join(hf_output_text)).ratio()
            print(output_text)
            print(hf_output_text)
            if inference_dtype != "float32":
                print("")
                print(
                    f"[CAVEAT] Comparing TRT-LLM {inference_dtype} results with HF bfloat16 results. Close matches are not expected!"
                )
            assert match_rate > 0.8, f"Incorrect results! Match rate {match_rate}"
            print(
                f"TRT-LLM results match HF results with literal match rate {match_rate}"
            )
Expected behavior
HF and TRT-LLM results should be roughly the same.
Actual behavior
HF output text: ['Keeping the Secret of Genetic Testing']
TRT-LLM output text: ['Keeping the Secret of Genetic Testing - The New York Times']
TensorRT also logs an error during the run:
[03/24/2024-05:29:38] [TRT] [E] 3: [engine.cpp::getProfileObliviousBindingIndex::1530] Error Code 3: Internal Error (setTensorAddress given invalid tensor name: attention_mask)
Additional notes
On a larger dataset the difference is very obvious. FasterTransformer gives results much closer to HF.
- We used TensorRT-LLM (v0.9) to build an engine and run inference tests on a T5 model. Many samples from the converted encoder-decoder engine do not align with the PyTorch model (many have a match rate of only 0.5).
- Someone previously reported the same problem in issue #612 ("enc_dec model results are not aligned with the HF model"). We are using the latest TensorRT-LLM release, v0.9, in which that issue should have been resolved; however, in our testing the misalignment persists.
- To investigate, we took the samples whose predictions differ from PyTorch, saved their encoder outputs as .npy files, and tested the decoder module on its own; the misalignment traced back to the cross-attention. We also observed that for misaligned samples the first few generated tokens always match, and the outputs diverge only later. So far we cannot determine which change in the GPT attention plugin may have caused this. We look forward to your response.
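A minimal sketch of the kind of dump-and-compare check described above, assuming NumPy and that matching HF and TRT-LLM tensors were saved with identical shapes. compare_tensors, first_divergent_layer, and the file names in the comments are hypothetical helpers, not part of either library; synthetic arrays stand in for real dumps so the block runs as-is.

```python
import numpy as np

def compare_tensors(ref: np.ndarray, test: np.ndarray, eps: float = 1e-6) -> dict:
    """Summarize how far `test` (e.g. a TRT-LLM dump) deviates from `ref` (e.g. an HF dump)."""
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    a = ref.astype(np.float64).ravel()  # compare in float64 so the check adds little rounding of its own
    b = test.astype(np.float64).ravel()
    diff = np.abs(a - b)
    return {
        "max_abs": float(diff.max()),
        "mean_abs": float(diff.mean()),
        "max_rel": float((diff / (np.abs(a) + eps)).max()),
        "cosine": float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps)),
        "argmax_abs": tuple(int(i) for i in np.unravel_index(int(diff.argmax()), ref.shape)),
    }

def first_divergent_layer(ref_layers, test_layers, tol: float = 1e-2):
    """Index of the first layer whose max abs deviation exceeds `tol`, or None."""
    for i, (r, t) in enumerate(zip(ref_layers, test_layers)):
        if compare_tensors(r, t)["max_abs"] > tol:
            return i
    return None

# Real use might look like (hypothetical file names):
#   ref_out = np.load("hf_decoder_layer0_cross_attn.npy")
#   trt_out = np.load("trt_decoder_layer0_cross_attn.npy")
rng = np.random.default_rng(0)
ref_out = rng.standard_normal((4, 16)).astype(np.float32)
trt_out = ref_out.copy()
trt_out[1, 2] += 0.1  # inject a single 0.1 deviation
print(compare_tensors(ref_out, trt_out))
```

Walking first_divergent_layer over per-layer dumps localizes where the deviation first exceeds a chosen tolerance, which is the same narrowing-down step the thread performs by hand.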
Any update on this issue?
Hi @sc-gr, @drivingchangeworld, @aashsach, although #612 drifted into unrelated discussion toward the end, I want to highlight the explanation I gave earlier, based on numerical analysis, which is buried in the comments of #612: https://github.com/NVIDIA/TensorRT-LLM/issues/612#issuecomment-1847063203, https://github.com/NVIDIA/TensorRT-LLM/issues/612#issuecomment-1851045753. "Conclusion first: this is normal, and it's likely an HF/PyTorch gemm problem.
I was obsessed with the same tiny numerical differences while developing enc-dec support. You're checking the encoder_output tensor, which has already gone through some numerical accumulation; I checked the Q, K, V tensors right after the QKV projection. The tiny deviation becomes a noticeable decimal difference after a few layers. I wanted to know what the ground truth is, so for FP32, Q = W*X, I saved one row of W and one column of X as tensors (i.e., the two vectors whose multiply-add gives one element of Q) and computed that element four ways:
(1) torch.matmul
(2) TRT-LLM without the gemm plugin
(3) TRT-LLM with the gemm plugin
(4) the golden standard, a hand calculation, for which I used NumPy
Findings: (4) == (3) ~= (2) != (1). So HF/PyTorch is not 100% reliable, and we shouldn't treat HF/PyTorch as the ground truth when it comes to tiny numerical differences. As a side note, even FP32 is not guaranteed to match perfectly, because each framework uses a different gemm algorithm selection strategy. Tiny differences propagate over layers and over sequence length, so a gap in results is GUARANTEED, and the gap is GUARANTEED to grow with longer sequences and more layers, i.e., large models. Model accuracy should be evaluated on downstream tasks instead of just numerically. For the example above, I believe the TRT-LLM result could even be read as higher quality than the HF result, although that doesn't prove anything either, since LLM generation is not a deterministic process.
A little more explanation of this numerical analysis: the effect may be more prominent for encoder-decoder models than for decoder-only models, and the reason is cross-attention. An encoder-decoder model first runs the encoder once to get the encoder output and the encoder KV cache; the decoder's cross-attention then does a matmul between the decoder input and that encoder KV cache. Following the accumulation argument above, the encoder output and KV cache have already accumulated numerical error through all the encoder layers. Cross-attention inherits that error and accumulates it further through all the decoder layers, which is why encoder-decoder models are more susceptible to such numerical deviation.
The key takeaway: we should evaluate on real downstream tasks and see whether, and by how much, such numerical differences affect output quality, rather than pursue an exact match of logit values. Of course, it's sometimes hard to tell an implementation bug from numerical deviation, but so far, based on our analysis and user feedback, we don't think this comes from an implementation bug in TRT-LLM's encoder-decoder models."
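The summation-order argument in the quoted explanation can be reproduced in a few lines, assuming NumPy. This is not TRT-LLM or PyTorch code. It computes one element of Q = W*X by accumulating the same float32 products in two different orders, which stand in for two gemm algorithms, and compares both with a float64 "hand calculation" (float32 × float32 products are exact in float64). Both results stay within the textbook float32 summation error bound, but nothing forces them to agree bit-for-bit with each other.

```python
import numpy as np

def dot_sequential_fp32(w: np.ndarray, x: np.ndarray) -> np.float32:
    """Accumulate w[i] * x[i] left to right, rounding to float32 after every step."""
    acc = np.float32(0.0)
    for wi, xi in zip(w.astype(np.float32), x.astype(np.float32)):
        acc = np.float32(acc + wi * xi)
    return acc

def dot_pairwise_fp32(w: np.ndarray, x: np.ndarray) -> np.float32:
    """Same products, reduced as a binary tree in float32 (a different, equally valid order)."""
    prods = w.astype(np.float32) * x.astype(np.float32)
    while prods.size > 1:
        if prods.size % 2:  # pad odd lengths with one zero so elements pair up
            prods = np.append(prods, np.float32(0.0))
        prods = prods[0::2] + prods[1::2]
    return prods[0]

def dot_reference_fp64(w: np.ndarray, x: np.ndarray) -> float:
    """'Hand calculation' stand-in: float32 products are exact in float64."""
    return float(np.dot(w.astype(np.float64), x.astype(np.float64)))

# Floating-point addition is not associative: 1.0 is below the float32 spacing (8.0) near 1e8.
big, one = np.float32(1e8), np.float32(1.0)
print((big + one) - big)                          # 0.0 in float32
print((np.float64(1e8) + 1.0) - np.float64(1e8))  # 1.0 in float64

rng = np.random.default_rng(0)
w = rng.standard_normal(4096).astype(np.float32)  # one row of W
x = rng.standard_normal(4096).astype(np.float32)  # one column of X
ref = dot_reference_fp64(w, x)
seq = dot_sequential_fp32(w, x)
pair = dot_pairwise_fp32(w, x)
print(f"sequential error {abs(float(seq) - ref):.3e}, pairwise error {abs(float(pair) - ref):.3e}")
```

In a transformer these per-element discrepancies are fed into the next layer, which is the propagation over layers and sequence length that the quoted comment describes.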
@drivingchangeworld's debugging narrowed the problem down to cross-attention, and earlier I had narrowed it further to just the Q*K gemm in cross-attention. If you're blocked by this issue, please comment on whether this is a valid explanation, or share better ideas for investigating further. Thanks!
@symphonylyh Thank you for the detailed response. In my case, the decoder outputs are way off compared to the HF model. I have tried optimizing with TensorRT as well as FasterTransformer, and both give results significantly closer to the HF PyTorch model.
However, outputs from TRT-LLM don't even come close to either of them. In many cases, it just hallucinates and keeps repeating the same sequence of sentences.
Maybe I am missing something here.
@aashsach, is your case identical to this issue, i.e., also Flan-T5-XXL with TP=4 in BF16, or a different model? I'm assuming you're aware that FP16 won't work here.
My model is Flan-T5-XL with TP=1. Yes, I am using bfloat16, not fp16.
Can you send a reproducer? If it's not a fine-tuned model, you can just post your example input, the expected output (from HF/FT), and the TRT-LLM output you saw.
We have encountered a similar issue where the T5-large model fails to align with the HF model. We used the test set provided by HF (question-answer pairs), and found that while some samples align perfectly, others do not. Similarly, we have traced this issue back to the cross-attention mechanism. This problem is challenging to pinpoint because it is not an obvious error.
After reviewing your analysis, we have a few questions:
Firstly, if there were discrepancies in matrix calculations, why is the error in self-attention computation minimal? We believe that cross-attention and self-attention are implemented similarly, with the only difference being the input.
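For reference on the question above, here is a minimal single-head sketch, assuming NumPy and following HF's T5 conventions: no 1/sqrt(d) score scaling, and a relative position bias only in self-attention (HF builds its T5 cross-attention layer with has_relative_attention_bias=False). It omits the Q/K/V/output projections and uses a random, hypothetical stand-in for the learned bias lookup. It shows that the two paths differ in more than their input: decoder self-attention adds a position bias and a causal mask, while cross-attention attends over the encoder output with neither.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def t5_style_attention(q, k, v, position_bias=None, causal=False):
    """Single-head attention with T5 conventions: no 1/sqrt(d) scaling, additive bias."""
    scores = q @ k.T  # [q_len, k_len]
    if position_bias is not None:
        scores = scores + position_bias
    if causal:
        # query i may attend keys 0..i only (queries and keys index the same positions)
        future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
d = 8
dec = rng.standard_normal((3, d))        # decoder hidden states for 3 positions
enc = rng.standard_normal((5, d))        # encoder output for 5 source tokens
self_bias = rng.standard_normal((3, 3))  # hypothetical stand-in for the learned bias lookup

# Decoder self-attention: Q, K, V from the decoder, relative position bias, causal mask.
self_out = t5_style_attention(dec, dec, dec, position_bias=self_bias, causal=True)
# Cross-attention: Q from the decoder, K/V from the encoder output, no bias, no causal mask
# (batched inputs would still need an encoder padding mask).
cross_out = t5_style_attention(dec, enc, enc)
```

Because the bias term exists only on the self-attention path, an error in how that bias is computed would not show up in a cross-attention-only comparison at its source, even though the resulting error would flow into cross-attention through the decoder hidden states.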
Additionally, we expect the model acceleration to strictly align with the HF model (with minor deviations allowed for a few samples). However, after testing numerous samples, we found that many do not match well, whereas they align well in the HF framework. In the run.py provided in the example, there are three test cases, and the last one does not align well with HF. Have you conducted extensive testing?
We also performed some tests and observed that it is not simply a matter of cumulative numerical errors.
Like the user above, we fed the HF model's encoder output to the TRT model's decoder (to rule out accumulated error from the encoder). We noticed that in some problematic samples, the first layer's self-attention aligns well with the HF model in the first inference step (with an error of around 0.001). In the cross-attention, however, the deviation reaches 0.1 (we also compared the cross-attention past key-values saved by HF and TRT; they deviate by only 0.001, which is acceptable). Of course, this issue does not occur in all samples. From this observation, we deduce that it may not be a simple cumulative numerical error but potentially a bug.
We also observed that in the t5-large model, the cross-past-key-value saved by the TRT model sometimes exhibits anomalies (there are occasional all-zero values in the sequence length dimension), which do not align with the HF model (the dimensions of TRT's cross-past-key-value and HF's are the same). This situation does not occur frequently, and we only noticed it by chance. We believe it could be a potential bug.
Finally, due to the elusive nature of this problem, it is challenging to determine its cause.
@0xd8b Thanks much for the extensive analysis! As you said, the issue is indeed subtle and elusive, and I really would like to start debugging on a concrete example so we're on the same page and sharing thoughts.
We noticed that in some problematic samples, the first layer's self-attention aligns well with the HF model in the initial inference (with an error of around 0.001). However, in the cross-attention, the calculation deviation reaches 0.1 (we also compared the cross-past-key-value saved by HF and TRT, with a numerical deviation of only 0.001, which is acceptable)
If you're using a public T5-large model, can you please share one or a few sample inputs that manifest this issue, so I can dig deeper? I will also run regression tests on earlier TRT-LLM versions, because some changes might have gone undetected due to an oversimplified accuracy test for enc-dec.
I tested some samples with the t5-large model and found that some had relatively low match scores. Analyzing them in detail, I found that starting from step 9 of the generation process there was a large error in self-attention, which cross-attention then amplified to the 10^-2 level; the error keeps growing in subsequent steps.
Sample 1:
input text: "translate English to German:You need to answer the question 'Does the given text contains any fact or opinion?', given a piece of text. Your answer must be either Fact or Opinion. If the answer is Fact, that means all statements in the given text are facts and can be proven by evidence. If the answer is Opinion, that means at least one statement in the given text is an opinion which cannot be proven by evidence. Statements that cannot be proven by evidence may still be true, but there is no way to know for certain if they are true or not. The output should be one of two values - 'Fact' or 'Opinion'."
hf_output: "Sie müssen die Frage 'Contient-le-text d'une fakt ou d'une opinion?' beantworten, wobei Sie einen Text vorgeben."
tensorrt_llm output: "Sie müssen die Frage 'Contient-le-text irgendeine Fakt oder Meinung?' beantworten, gegeben einen Text. Ihre Antwort muss entweder Fakt oder Meinung sein."
Sample 2:
input text: "translate English to German:You are given a list of words. Your task is to find the number of anagrams in the list. Two words are anagrams if they have the same characters but in a different order."
hf_output: "Sie erhalten eine Liste von Wörtern, die Sie ausfindig machen müssen, um die Anzahl der Anagramme in der Liste zu finden."
tensorrt_llm output: "Sie erhalten eine Liste von Wörtern, die Sie in der Reihenfolge der Wörter in der Liste anagrammieren müssen."
Thank you for your response! We are using the T5-Large model, but we have fine-tuned it, which makes sharing the model and test samples difficult. However, I will try my best to find clear examples of errors to help localize the problem. For now, I have disabled the gpt-attention-plugin and rebuilt the engine. The inference results are 100% identical to the HF model, with a slight increase in GPU memory usage, and the inference speed is comparable to that with the plugin. Is this expected? (Using the plugin should normally give a significant speed improvement.) Looking forward to your reply!
@0xd8b may I ask what version of TensorRT-LLM you're using? I wanted to try disabling the gpt-attention-plugin too, but I'm on v0.8.0, where I can't disable only the gpt-attention-plugin because of a check that reports "No support for relative attention bias in plain TRT mode." See v0.8.0: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.8.0/examples/enc_dec/build.py#L379
@jerrylinamazon, v0.8.0 has this check, but the latest main branch doesn't, so there you can disable the GPT attention plugin.
@sc-gr We're using the latest code.
@sc-gr , @0xd8b , @jerrylinamazon , @drivingchangeworld , @aashsach ,
TL;DR: Great news, folks! We found the bug in the relative attention bias computation (so it affects the T5 family only); the fix is a one-line change in decoderMaskedMultiheadAttentionTemplate.h, after which results are almost identical to HF. The change will be released in next Tuesday's (Apr 30) weekly update.
Details:
- [1] HF original implementation
- [2] TRT-LLM non-plugin implementation
- [3] TRT-LLM plugin implementation of relative attention bias, which includes:
  (a) bias addition in the encoder, in the BertAttention plugin
  (b) bias addition in the decoder's context phase, in the GPTAttention plugin
  (c) bias addition in the decoder's generation phase, in the GPTAttention plugin, underlying decoderMaskedMultiheadAttentionTemplate.h
[1] and [2] are identical, meaning the TRT-LLM non-plugin path is accurate. [3] differs from [1] and [2] because of the bidirectional flag in the relative attention bias computation. The correct behavior is bidirectional=True for the encoder and bidirectional=False for the decoder. In [3], 3(a) and 3(b) were right, but 3(c) used bidirectional=True; we have now fixed it to follow the bidirectional=False logic. To get the fix sooner, change decoderMaskedMultiheadAttentionTemplate.h from
num_buckets /= 2;
relative_buckets += relative_position > 0 ? num_buckets : 0;
relative_position = abs(relative_position);
to
relative_position = relative_position > 0 ? 0 : -relative_position;
Note: there are two such occurrences in the file, and you need to change both.
It's indeed very tricky: it only manifests once the generation step becomes large enough that positions fall into different relative attention buckets.
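To see why the bug stays hidden for short outputs, here is a scalar Python port of HF T5's _relative_position_bucket, offered as a sketch: it uses float64 logs where HF uses float32 tensors, so it could in principle differ at exact bucket boundaries. Its non-bidirectional branch computes the same thing as the replacement line above. During generation every key sits at or before the query, so relative_position = key_pos - query_pos <= 0. With T5's default 32 buckets and max distance 128, the buggy (bidirectional=True) and fixed (bidirectional=False) paths map distances 0 through 8 to the same bucket and first disagree at distance 9. That appears consistent with the report earlier in this thread of errors starting around generation step 9.

```python
import math

def relative_position_bucket(relative_position: int, bidirectional: bool,
                             num_buckets: int = 32, max_distance: int = 128) -> int:
    """Scalar port of HF T5's bucketing; relative_position = key_pos - query_pos."""
    bucket = 0
    if bidirectional:
        # encoder: half the buckets for keys after the query, half for keys before it
        num_buckets //= 2
        bucket += num_buckets if relative_position > 0 else 0
        relative_position = abs(relative_position)
    else:
        # causal decoder: future keys collapse to 0, past keys use their distance
        # (same as `relative_position > 0 ? 0 : -relative_position`)
        relative_position = 0 if relative_position > 0 else -relative_position
    max_exact = num_buckets // 2
    if relative_position < max_exact:  # small distances get one bucket each
        return bucket + relative_position
    # larger distances share logarithmically wider buckets, capped at the last one
    log_bucket = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(log_bucket, num_buckets - 1)

# Generation phase: compare the buggy and fixed buckets for keys d positions in the past.
first_divergence = next(
    d for d in range(128)
    if relative_position_bucket(-d, bidirectional=True)
    != relative_position_bucket(-d, bidirectional=False)
)
print(first_divergence)  # 9: distances 0..8 land in the same bucket either way
```

Because short generations never reach that distance, an accuracy test with short outputs would pass with either formula.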
Please test on your models & applications, and let me know if this works!
@aashsach yes, both places need to be changed. Let me know!
Gotcha... Tested on some examples, and it seems to be working fine now. Will update after exhaustive testing.
@symphonylyh Thank you very much for your thorough analysis and for resolving the issue. We have applied the change and tested it: with the GPT attention plugin, the T5 model now aligns with the HF model. However, when we convert the model to fp16, testing shows that with batch_size 1 and high GPU utilization (100%), the encoder outputs NaN values. We suspect the issue lies in the RMS_norm layer, which is likely a bug.
@0xd8b that looks like an fp16 overflow issue, probably directly related to your fine-tuned T5 weights. We have observed similar things with custom T5 models before, and the solution then was for the customer to retrain their model with a guard that keeps weight magnitudes under a certain threshold. This may or may not apply in your case. Alternatively, can you try converting the FP16 weights to BF16 (during your HF checkpoint export or TRT-LLM weight conversion) and see if that works?
Lastly, I would suggest you open a separate issue and ping me there, so we can help with that as a standalone topic.
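A minimal NumPy sketch of the overflow mechanism discussed above, assuming a T5-style RMSNorm (scale by 1/sqrt(mean(x^2)), with no mean subtraction and no bias). HF's T5LayerNorm computes the variance in float32; whether a given TRT-LLM kernel does the same is not established in this thread, so the accumulation dtype is a parameter here. float16 tops out at 65504, so squaring an activation above about 256 overflows to inf and silently zeroes the normalized row. An activation that has itself already overflowed yields NaN through inf/inf. BF16 has the same exponent range as FP32, which is why converting to BF16 may avoid this.

```python
import numpy as np

def t5_rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6,
                accum_dtype=np.float32) -> np.ndarray:
    """T5-style RMSNorm: x / sqrt(mean(x^2) + eps) * weight, no mean subtraction, no bias."""
    with np.errstate(over="ignore", invalid="ignore"):  # silence warnings; inf/NaN still propagate
        xa = x.astype(accum_dtype)
        variance = np.mean(xa * xa, axis=-1, keepdims=True)
        y = xa / np.sqrt(variance + accum_dtype(eps))
        return (y * weight.astype(accum_dtype)).astype(x.dtype)

hidden = np.array([[300.0, 1.0, 2.0, 3.0]], dtype=np.float16)  # 300^2 = 90000 > fp16 max 65504
gamma = np.ones(4, dtype=np.float16)

print(t5_rms_norm(hidden, gamma, accum_dtype=np.float32))  # finite: squares accumulated in fp32
print(t5_rms_norm(hidden, gamma, accum_dtype=np.float16))  # all zeros: variance overflowed to inf
overflowed = np.array([[np.inf, 1.0, 2.0, 3.0]], dtype=np.float16)  # an earlier layer already overflowed
print(t5_rms_norm(overflowed, gamma))  # NaN in the first slot: inf / inf
```

The sketch only shows how large activations turn into zeros or NaN. It does not explain why GPU utilization would change whether they appear.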
@symphonylyh I will submit a new issue. This is an interesting phenomenon:
- Using float32 type:
- GPU initial usage is 0%, model inference works correctly, and the inference results are accurate.
- GPU initial usage is 100%, model inference works normally.
- Using float16 type:
- GPU initial usage is 0%, model inference works correctly, and the inference results are accurate.
- GPU initial usage is 100%, model inference results in abnormalities (encoder outputs NaN values).
We truncated the model output. Why would different GPU usage rates lead to overflow in the model output? This is an interesting question.
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.