Missing logits in Executor API when using `return_generation_logits`
System Info
- Nvidia A40
- CUDA 12.2
- TensorRT 10.0.1.6
- TensorRT-LLM 0.10.0.dev2024050700
Who can help?
@byshiue
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
When requesting generation logits via the Executor API's Python bindings, in certain cases one entire generation step is missing from the logits tensor. As a result, all subsequent generation steps in the logits tensor are shifted by one and no longer match the generated tokens.
This appears to happen specifically when generation terminates early because end_id is produced: if a custom stop sequence is hit, or if the maximum number of new tokens is reached, the returned logits are correct.
The problem can be reproduced with the steps below:
python convert_checkpoint.py --model_dir ./falcon_7b_tp1_instruct/ --dtype bfloat16 --output_dir ./falcon_7b_tp1_instruct_trt_chkpt
trtllm-build --checkpoint_dir ./falcon_7b_tp1_instruct_trt_chkpt/ --gemm_plugin bfloat16 --remove_input_padding enable --gpt_attention_plugin bfloat16 --output_dir ./falcon_7b_tp1_instruct_p200_g200 --gather_all_token_logits --max_input_len 200 --max_output_len 200 --max_batch_size 64
python example_basic.py --model_path ./falcon_7b_tp1_instruct_p200_g200
The examples/bindings/executor/example_basic.py script was modified to issue a request that triggers the issue, and to print the arg-max of the logits at each generation step. Below is the diff of the modified script:
diff --git a/examples/bindings/executor/example_basic.py b/examples/bindings/executor/example_basic.py
index 2c7a3fc..3f1991f 100644
--- a/examples/bindings/executor/example_basic.py
+++ b/examples/bindings/executor/example_basic.py
@@ -1,4 +1,5 @@
import argparse
+import torch
import tensorrt_llm.bindings.executor as trtllm
@@ -21,8 +22,10 @@ if __name__ == "__main__":
if executor.can_enqueue_requests():
# Create the request.
- request = trtllm.Request(input_token_ids=[1, 2, 3, 4],
- max_new_tokens=10)
+ request = trtllm.Request(input_token_ids=[100, 20, 3, 18],
+ max_new_tokens=20,
+ end_id=25,
+ output_config=trtllm.OutputConfig(return_generation_logits=True))
# Enqueue the request.
request_id = executor.enqueue_request(request)
@@ -30,6 +33,9 @@ if __name__ == "__main__":
# Wait for the new tokens.
responses = executor.await_responses(request_id)
output_tokens = responses[0].result.output_token_ids
+ output_top_tokens = torch.argmax(responses[0].result.generation_logits[0], dim=1).tolist()
+
# Print tokens.
- print(output_tokens)
+ print(f"Output tokens: {output_tokens[0][4:]}")
+ print(f"Logits arg-max: {output_top_tokens}")
Expected behavior
Since we use top_k=1 (greedy decoding) and do not sample tokens, the argmax of the logits at each generation step should exactly match the tokens returned for the request.
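As a minimal sketch of this invariant, reusing output_tokens and output_top_tokens from the modified script above:

# Greedy decoding (top_k=1): each logits row's argmax should equal the token
# emitted at the corresponding generation step.
generated = output_tokens[0][4:]  # strip the 4 prompt tokens
assert output_top_tokens[:len(generated)] == generated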
Actual behavior
The generated tokens and the argmax of the logits do not match, and the latter is missing one entire generation step:
Output tokens: [94, 241, 914, 818, 271, 577, 402, 2862, 271, 1730, 544, 248, 1079, 1111, 612]
Logits arg-max: [94, 241, 914, 818, 271, 577, 402, 2862, 1730, 544, 248, 1079, 1111, 612, 25, 0, 0, 0, 0, 0]
Notice how token 271 is missing toward the end of the logits argmax sequence, and how all subsequent tokens are shifted by 1.
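A small hypothetical helper, reusing the names from the modified script above, can locate the first step where the two sequences diverge:

def first_divergence(tokens, argmax):
    # Return the first generation step where the logits argmax no longer
    # matches the emitted token, or None if the sequences agree.
    for step, (tok, top) in enumerate(zip(tokens, argmax)):
        if tok != top:
            return step
    return None

# With the output above this prints 8: the logits row for token 271 is
# absent, so every later row is shifted left by one position.
print(first_divergence(output_tokens[0][4:], output_top_tokens))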
Additional notes
The issue was observed on all TensorRT-LLM 0.10 dev versions, up to 0.10.0.dev2024050700.
Thanks for filing this issue @AlessioNetti, I was able to reproduce the bug. Taking a look now.
I think I found the issue. We should be able to get the fix in soon.
Hi @AlessioNetti, do you still have any further issues or questions? If not, we'll close this issue soon.
Hi again, this specific issue was addressed, but minor issues remain as of version 0.15.0.dev2024110500, again regarding generation logits and early stopping due to end_id.
In particular, for the last few versions the logits for the very last generation step (the one that produced end_id) have been missing from the output tensor. Running the example above results in the following output:
Output tokens: [94, 241, 914, 818, 271, 577, 402, 2862, 271, 1730, 544, 248, 1079, 1111, 612]
Logits arg-max: [94, 241, 914, 818, 271, 577, 402, 2862, 271, 1730, 544, 248, 1079, 1111, 612, 0, 0, 0, 0, 0]
Notice how 25 is missing from the logits arg-max list. The corresponding list element happens to be 0 here, but we noticed it can be an arbitrary value depending on which requests were previously issued to the Executor.
We are not sure whether this is intended behavior; for the moment, specifying the end_id as a stop word appears to be an effective workaround, as sketched below.
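For completeness, a sketch of that workaround, assuming the Request constructor in these bindings accepts a stop_words argument given as a list of token-id sequences (please double-check against your bindings version):

# Workaround sketch: pass the end_id token (25 here) as an explicit stop word
# so that the logits of the final generation step are retained.
request = trtllm.Request(
    input_token_ids=[100, 20, 3, 18],
    max_new_tokens=20,
    stop_words=[[25]],  # assumed: list of stop sequences of token ids
    output_config=trtllm.OutputConfig(return_generation_logits=True))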
@trevor-m, could you please review @AlessioNetti's feedback?
Hi @AlessioNetti, it looks like we intentionally changed this so that the logits length will match the tokens. We will consider whether it should be added back.