text-generation-inference
Top-p sampling returns incorrect/missing log-probabilities in details
System Info
I noticed text-generation-inference is providing incorrect log-probabilities in the details whenever top-p sampling is used. I'm running text-generation-inference via Docker on a single NVIDIA A100, using the k8s pod definition below. I believe it should be the latest image; I pulled it for the first time today:
containers:
  - name: bking2-tgi-server
    image: ghcr.io/huggingface/text-generation-inference:1.4
    args:
      - "--model-id"
      - "bigcode/starcoder"
      - "--quantize"
      - "bitsandbytes"
      - "-p"
      - "21042"
      - "--max-input-length"
      - "4095"
      - "--max-total-tokens"
      - "4096"
      - "--huggingface-hub-cache"
      - "/data/users/bking2/.cache/huggingface"
When calling with top-p, a number of the log-probabilities returned are exactly 0.0 and most or all of the rest are incorrect. Example call (details in Reproduction):
result = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'], details=True,
                                top_p=0.7, do_sample=True, best_of=2, decoder_input_details=True)
Interestingly, full sampling does not seem to have this issue.
Information
- [X] Docker
- [ ] The CLI directly
Tasks
- [X] An officially supported command
- [ ] My own modifications
Reproduction
Here's a minimal working example that reproduces the error for me:
from huggingface_hub import InferenceClient

prompts = [
    "def hello_world():\n print('Hello,",
    "def fibonacci(n):\n if n <= 1:\n return n\n else:\n return (",
    "def factorial(n):\n if n == 1:\n return n\n else:\n return (n *",
]

if __name__ == '__main__':
    client = InferenceClient(model="http://0.0.0.0:21042")
    for p in prompts:
        # Greedy Decode
        result = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'], details=True,
                                        decoder_input_details=True)
        print("========= Greedy Decode ==========")
        print("Prompt:", p)
        print("Prompt IDs:", [t.id for t in result.details.prefill])
        print('Completion', repr(result.generated_text))
        print('Completion IDs:', [t.id for t in result.details.tokens])
        print("Log Probs:", [t.logprob for t in result.details.tokens])
        print("total Log Prob:", sum(t.logprob for t in result.details.tokens))
        greedy_completion = result.generated_text

        # Top-p Decode
        print("========= Top-p Decode ==========")
        result = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'], details=True,
                                        top_p=0.7, do_sample=True, best_of=2, decoder_input_details=True)
        print('Completion', repr(result.generated_text))
        print('Completion IDs:', [t.id for t in result.details.tokens])
        assert result.generated_text == greedy_completion
        print("Log Probs:", [t.logprob for t in result.details.tokens])
        print("total Log Prob:", sum(t.logprob for t in result.details.tokens))
        print('\n\n')
Which produced this output:
========= Greedy Decode ==========
Prompt: def hello_world():
print('Hello,
Prompt IDs: [589, 17964, 81, 5860, 2262, 664, 1459, 463, 8279, 30]
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.4909668, -0.14099121]
total Log Prob: -0.63195801
========= Top-p Decode ==========
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.3918457, 0.0]
total Log Prob: -0.3918457
========= Greedy Decode ==========
Prompt: def fibonacci(n):
if n <= 1:
return n
else:
return (
Prompt IDs: [589, 28176, 34682, 26, 96, 711, 284, 415, 310, 2511, 225, 35, 44, 291, 442, 310, 284, 813, 44, 291, 442, 308]
Completion 'fibonacci(n-1)'
Completion IDs: [26423, 34682, 26, 96, 31, 35, 27]
Log Probs: [-0.037200928, -0.003162384, -0.005218506, -0.0017595291, -0.5180664, -0.09063721, -0.14355469]
total Log Prob: -0.7995996471000001
========= Top-p Decode ==========
Completion 'fibonacci(n-1)'
Completion IDs: [26423, 34682, 26, 96, 31, 35, 27]
Log Probs: [0.0, 0.0, 0.0, 0.0, -0.5292969, 0.0, 0.0]
total Log Prob: -0.5292969
========= Greedy Decode ==========
Prompt: def factorial(n):
if n == 1:
return n
else:
return (n *
Prompt IDs: [589, 10365, 564, 26, 96, 711, 284, 415, 310, 610, 225, 35, 44, 291, 442, 310, 284, 813, 44, 291, 442, 308, 96, 319]
Completion ' factorial(n-1))'
Completion IDs: [10365, 564, 26, 96, 31, 35, 490]
Log Probs: [-0.018920898, -8.523464e-05, -0.00712204, -0.0030555725, -0.55078125, -0.0008993149, -0.0066070557]
total Log Prob: -0.5874713657399999
========= Top-p Decode ==========
Completion ' factorial(n-1))'
Completion IDs: [10365, 564, 26, 96, 31, 35, 490]
Log Probs: [0.0, 0.0, 0.0, 0.0, -0.5761719, 0.0, 0.0]
total Log Prob: -0.5761719
And logs from the server side:
2024-03-06T22:33:59.309499Z INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="373.282406ms" validation_time="301.4µs" queue_time="58.959µs" inference_time="372.922197ms" time_per_token="186.461098ms" seed="None"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:00.013313Z INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: Some(2), temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: Some(0.7), typical_p: None, do_sample: true, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="701.589188ms" validation_time="187.85µs" queue_time="300.879846ms" inference_time="400.521592ms" time_per_token="200.260796ms" seed="Some(8121663308226373158)"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:00.853564Z INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="838.047278ms" validation_time="246.059µs" queue_time="35.96µs" inference_time="837.765409ms" time_per_token="119.680772ms" seed="None"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:02.050050Z INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: Some(2), temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: Some(0.7), typical_p: None, do_sample: true, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="1.194594105s" validation_time="301.579µs" queue_time="268.106817ms" inference_time="926.185909ms" time_per_token="132.312272ms" seed="Some(12980777158106879273)"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:02.859850Z INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="807.290596ms" validation_time="339.869µs" queue_time="18.93µs" inference_time="806.931947ms" time_per_token="115.275992ms" seed="None"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:03.967174Z INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: Some(2), temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: Some(0.7), typical_p: None, do_sample: true, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="1.105341106s" validation_time="284.58µs" queue_time="268.824896ms" inference_time="836.23176ms" time_per_token="119.46168ms" seed="Some(13750410545473914160)"}: text_generation_router::server: router/src/server.rs:305: Success
I verified outside of text-generation-inference that the greedy log-probabilities roughly match what one gets by calling the model directly in transformers (a rough sketch of this check follows the output below). Full sampling also matches expectations:
========= Greedy Decode ==========
Prompt: def hello_world():
print('Hello,
Prompt IDs: [589, 17964, 81, 5860, 2262, 664, 1459, 463, 8279, 30]
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.4909668, -0.14099121]
total Log Prob: -0.63195801
========= Full Sampling Decode ==========
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.46899414, -0.14025879]
total Log Prob: -0.60925293
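For reference, the transformers-side check was along these lines (a rough sketch only: the loading settings here are illustrative, and since the server runs bitsandbytes quantization the values only roughly match):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigcode/starcoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder",
                                             torch_dtype=torch.float16,
                                             device_map="auto")

prompt = "def hello_world():\n print('Hello,"
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=2, do_sample=False,
                     return_dict_in_generate=True, output_scores=True)

# out.scores holds one logits tensor per generated step; take the log-prob of
# the token that was actually generated at each step.
gen_ids = out.sequences[0, inputs.input_ids.shape[1]:]
logprobs = [torch.log_softmax(score, dim=-1)[0, tid].item()
            for score, tid in zip(out.scores, gen_ids)]
print(gen_ids.tolist(), logprobs, sum(logprobs))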
Expected behavior
When sampling with top_p < 1, the log-probabilities in details should be nearly equal to those from greedy decoding or full sampling.
@kingb12 this is because the logprobs are computed relative to the remaining shortlist of tokens after the top_p cutoff has been applied.
A logprob value of 0.0 is common because it occurs whenever a single token by itself accounts for the top_p of the original probability mass (log(1.0) = 0.0).
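To make that concrete, here is a toy sketch (made-up logits and a simplified top-p filter, not the exact TGI implementation) showing how renormalizing over the post-cutoff shortlist yields a 0.0 logprob when one token covers the whole top_p mass:

import torch

# Made-up logits for a 5-token vocabulary; not real StarCoder values.
logits = torch.tensor([4.0, 1.0, 0.5, 0.2, -1.0])
probs = torch.softmax(logits, dim=-1)

# Keep the smallest set of tokens whose cumulative probability reaches top_p.
top_p = 0.7
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
keep = (cumulative - sorted_probs) < top_p  # keep tokens whose preceding mass is < top_p
filtered = torch.full_like(logits, float("-inf"))
filtered[sorted_idx[keep]] = logits[sorted_idx[keep]]

# Log-prob of the top token over the full vocabulary vs. over the shortlist.
print(torch.log_softmax(logits, dim=-1)[0].item())    # ~ -0.10
print(torch.log_softmax(filtered, dim=-1)[0].item())  # 0.0: it survives the cutoff alone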
Ah, I see, that makes sense. I can see both interpretations being useful as long as they're documented; if there's a document where it would make sense to include this information, I'd be happy to try adding it.
For my use case, I've worked around this by just taking the returned sequence and calling the API again in greedy mode with 1 new token, then getting the log-probabilities from the prefill details with decoder_input_details=True.
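Roughly, the pattern looks like this (a sketch; it assumes prompt + completion tokenizes back into the prompt tokens followed by the completion tokens, which held for my prompts):

from huggingface_hub import InferenceClient

client = InferenceClient(model="http://0.0.0.0:21042")
p = "def hello_world():\n print('Hello,"

# 1) Sample with top_p as usual, keeping the prefill details around.
sampled = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'],
                                 top_p=0.7, do_sample=True, details=True,
                                 decoder_input_details=True)

# 2) Re-score greedily with 1 new token, appending the sampled completion to the
#    prompt; the prefill details then carry unfiltered log-probabilities.
rescored = client.text_generation(prompt=p + sampled.generated_text, max_new_tokens=1,
                                  details=True, decoder_input_details=True)

# 3) Drop the prompt portion of the prefill; what's left are the log-probs of the
#    sampled tokens under the full (un-truncated) distribution.
n_prompt = len(sampled.details.prefill)
logprobs = [t.logprob for t in rescored.details.prefill[n_prompt:]]
print(logprobs, sum(logprobs))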
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.