
Top-p sampling returns incorrect/missing log-probabilities in details

Open kingb12 opened this issue 1 year ago • 3 comments

System Info

I noticed that text-generation-inference returns incorrect log-probabilities in the details whenever top-p sampling is used. I'm running text-generation-inference via Docker, using this k8s pod definition on a single NVIDIA A100. I believe it should be the latest image, since I pulled it for the first time today:

containers:
- name: bking2-tgi-server
  image: ghcr.io/huggingface/text-generation-inference:1.4
  args:
  - "--model-id"
  - "bigcode/starcoder"
  - "--quantize"
  - "bitsandbytes"
  - "-p"
  - "21042"
  - "--max-input-length"
  - "4095"
  - "--max-total-tokens"
  - "4096"
  - "--huggingface-hub-cache"
  - "/data/users/bking2/.cache/huggingface"

When calling with top-p, a number of the log-probabilities returned are exactly 0.0 and most or all of the rest are incorrect. Example call (details in Reproduction):

result = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'], details=True,
                                top_p=0.7, do_sample=True, best_of=2, decoder_input_details=True)

Interestingly, full sampling does not seem to have this issue.

Information

  • [X] Docker
  • [ ] The CLI directly

Tasks

  • [X] An officially supported command
  • [ ] My own modifications

Reproduction

Here's a minimal working example that reproduced the error for me:

from huggingface_hub import InferenceClient

prompts = [
    "def hello_world():\n   print('Hello,",
    "def fibonacci(n):\n    if n <= 1:\n        return n\n    else:\n        return (",
    "def factorial(n):\n    if n == 1:\n        return n\n    else:\n        return (n *",
]

if __name__ == '__main__':
    client = InferenceClient(model="http://0.0.0.0:21042")
    for p in prompts:
        # Greedy Decode
        result = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'], details=True,
                                        decoder_input_details=True)
        print("========= Greedy Decode ==========")
        print("Prompt:", p)
        print("Prompt IDs:", [t.id for t in result.details.prefill])
        print('Completion', repr(result.generated_text))
        print('Completion IDs:', [t.id for t in result.details.tokens])
        print("Log Probs:", [t.logprob for t in result.details.tokens])
        print("total Log Prob:", sum(t.logprob for t in result.details.tokens))
        greedy_completion = result.generated_text
        # Top-p Decode
        print("========= Top-p Decode ==========")
        result = client.text_generation(prompt=p, max_new_tokens=128, stop_sequences=[')'], details=True,
                                        top_p=0.7, do_sample=True, best_of=2, decoder_input_details=True)
        print('Completion', repr(result.generated_text))
        print('Completion IDs:', [t.id for t in result.details.tokens])
        assert result.generated_text == greedy_completion
        print("Log Probs:", [t.logprob for t in result.details.tokens])
        print("total Log Prob:", sum(t.logprob for t in result.details.tokens))
        print('\n\n')

Which produced this output:

========= Greedy Decode ==========
Prompt: def hello_world():
   print('Hello,
Prompt IDs: [589, 17964, 81, 5860, 2262, 664, 1459, 463, 8279, 30]
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.4909668, -0.14099121]
total Log Prob: -0.63195801
========= Top-p Decode ==========
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.3918457, 0.0]
total Log Prob: -0.3918457
========= Greedy Decode ==========
Prompt: def fibonacci(n):
    if n <= 1:
        return n
    else:
        return (
Prompt IDs: [589, 28176, 34682, 26, 96, 711, 284, 415, 310, 2511, 225, 35, 44, 291, 442, 310, 284, 813, 44, 291, 442, 308]
Completion 'fibonacci(n-1)'
Completion IDs: [26423, 34682, 26, 96, 31, 35, 27]
Log Probs: [-0.037200928, -0.003162384, -0.005218506, -0.0017595291, -0.5180664, -0.09063721, -0.14355469]
total Log Prob: -0.7995996471000001
========= Top-p Decode ==========
Completion 'fibonacci(n-1)'
Completion IDs: [26423, 34682, 26, 96, 31, 35, 27]
Log Probs: [0.0, 0.0, 0.0, 0.0, -0.5292969, 0.0, 0.0]
total Log Prob: -0.5292969
========= Greedy Decode ==========
Prompt: def factorial(n):
    if n == 1:
        return n
    else:
        return (n *
Prompt IDs: [589, 10365, 564, 26, 96, 711, 284, 415, 310, 610, 225, 35, 44, 291, 442, 310, 284, 813, 44, 291, 442, 308, 96, 319]
Completion ' factorial(n-1))'
Completion IDs: [10365, 564, 26, 96, 31, 35, 490]
Log Probs: [-0.018920898, -8.523464e-05, -0.00712204, -0.0030555725, -0.55078125, -0.0008993149, -0.0066070557]
total Log Prob: -0.5874713657399999
========= Top-p Decode ==========
Completion ' factorial(n-1))'
Completion IDs: [10365, 564, 26, 96, 31, 35, 490]
Log Probs: [0.0, 0.0, 0.0, 0.0, -0.5761719, 0.0, 0.0]
total Log Prob: -0.5761719

And logs from the server side:

2024-03-06T22:33:59.309499Z  INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="373.282406ms" validation_time="301.4µs" queue_time="58.959µs" inference_time="372.922197ms" time_per_token="186.461098ms" seed="None"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:00.013313Z  INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: Some(2), temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: Some(0.7), typical_p: None, do_sample: true, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="701.589188ms" validation_time="187.85µs" queue_time="300.879846ms" inference_time="400.521592ms" time_per_token="200.260796ms" seed="Some(8121663308226373158)"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:00.853564Z  INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="838.047278ms" validation_time="246.059µs" queue_time="35.96µs" inference_time="837.765409ms" time_per_token="119.680772ms" seed="None"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:02.050050Z  INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: Some(2), temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: Some(0.7), typical_p: None, do_sample: true, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="1.194594105s" validation_time="301.579µs" queue_time="268.106817ms" inference_time="926.185909ms" time_per_token="132.312272ms" seed="Some(12980777158106879273)"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:02.859850Z  INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="807.290596ms" validation_time="339.869µs" queue_time="18.93µs" inference_time="806.931947ms" time_per_token="115.275992ms" seed="None"}: text_generation_router::server: router/src/server.rs:305: Success
2024-03-06T22:34:03.967174Z  INFO compat_generate{default_return_full_text=true compute_type=Extension(ComputeType("1-nvidia-a100-sxm4-80gb"))}:generate{parameters=GenerateParameters { best_of: Some(2), temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: Some(0.7), typical_p: None, do_sample: true, max_new_tokens: Some(128), return_full_text: Some(false), stop: [")"], truncate: None, watermark: false, details: true, decoder_input_details: true, seed: None, top_n_tokens: None, grammar: None } total_time="1.105341106s" validation_time="284.58µs" queue_time="268.824896ms" inference_time="836.23176ms" time_per_token="119.46168ms" seed="Some(13750410545473914160)"}: text_generation_router::server: router/src/server.rs:305: Success

I verified outside of text-generation-inference that the greedy log-probabilities roughly match those obtained by calling the model directly with transformers. Full sampling also matches expectations:

========= Greedy Decode ==========
Prompt: def hello_world():
   print('Hello,
Prompt IDs: [589, 17964, 81, 5860, 2262, 664, 1459, 463, 8279, 30]
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.4909668, -0.14099121]
total Log Prob: -0.63195801
========= Full Sampling Decode ==========
Completion " World!')"
Completion IDs: [10896, 22735]
Log Probs: [-0.46899414, -0.14025879]
total Log Prob: -0.60925293

Expected behavior

When sampling with top_p < 1, the log-probabilities in details should be nearly equal to those returned by greedy decoding or full sampling.

kingb12, Mar 06 '24 22:03

@kingb12 This is because the logprobs are computed relative to the shortlist of tokens that remains after the top_p cutoff has been applied, not relative to the full vocabulary.

A logprob value of 0 is common: it occurs whenever only a single token survives the top_p cutoff, so that token holds all of the remaining probability mass (log(1.0) = 0.0).

njhill, Mar 07 '24 00:03
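The renormalization described above can be illustrated with a small pure-Python sketch. It uses a simplified nucleus filter (keep the highest-probability tokens until their cumulative mass reaches top_p), which is an assumption and not TGI's actual implementation; `nucleus_logprobs` is a hypothetical helper. It returns both the full-distribution log-probabilities and the ones renormalized over the kept shortlist:

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_softmax(logits):
    """Log-probabilities over the full vocabulary."""
    lse = logsumexp(logits)
    return [x - lse for x in logits]


def nucleus_logprobs(logits, top_p):
    """Return (full, renormalized) log-probabilities for a toy vocabulary.

    Hypothetical helper: a simplified top-p filter that keeps tokens in
    descending probability order until their cumulative mass reaches top_p.
    Tokens outside the shortlist get -inf in the renormalized output.
    """
    full = log_softmax(logits)
    order = sorted(range(len(full)), key=lambda i: full[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += math.exp(full[i])
        if cumulative >= top_p:
            break
    # Renormalize: subtract the log of the probability mass that was kept.
    kept_lse = logsumexp([full[i] for i in kept])
    renorm = [float("-inf")] * len(full)
    for i in kept:
        renorm[i] = full[i] - kept_lse
    return full, renorm
```

With probabilities [0.8, 0.15, 0.05] and top_p=0.7, only the first token survives, so its renormalized logprob is exactly 0.0 while its full-distribution logprob is log(0.8) ≈ -0.223. When several tokens survive, each renormalized value equals the full value minus the log of the kept mass, so it is always closer to zero. That is consistent with the pattern in the report (e.g. -0.3918457 under top-p vs -0.4909668 under greedy).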

Ah, I see; that makes sense. Both meanings are useful as long as they're documented. If there's a document where this information belongs, I'd be happy to try adding it.

For my use case, I've worked around this by taking the returned sequence, calling the API again in greedy mode with 1 new token and decoder_input_details=True, and reading the log-probabilities from the prefill details.

kingb12, Mar 07 '24 21:03
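The re-scoring workaround above can be sketched as follows, assuming a huggingface_hub `InferenceClient` (or any object with the same `text_generation` method) pointed at the TGI server. `completion_logprobs` is a hypothetical helper, not part of TGI. It assumes that re-tokenizing prompt + completion ends with the same token ids that were sampled (it raises if not), and that the prefill logprobs of a greedy request come from the unfiltered distribution, as the greedy results in the report suggest:

```python
def completion_logprobs(client, prompt, completion_text, completion_ids):
    """Re-score a sampled completion and return its per-token log-probabilities.

    Hypothetical helper. `client` is assumed to expose text_generation(...)
    like huggingface_hub.InferenceClient.
    """
    completion_ids = list(completion_ids)
    if not completion_ids:
        return []
    # No sampling parameters are passed, so the request should be greedy.
    # Only one new token is generated; we only need the prefill details.
    result = client.text_generation(
        prompt=prompt + completion_text,
        max_new_tokens=1,
        details=True,
        decoder_input_details=True,
    )
    prefill = result.details.prefill
    # The completion should occupy the last len(completion_ids) prefill tokens.
    tail = prefill[-len(completion_ids):]
    if [t.id for t in tail] != completion_ids:
        raise ValueError(
            "prompt + completion re-tokenized differently; "
            "prefill tail does not match the sampled token ids"
        )
    return [t.logprob for t in tail]
```

For example, after the top-p call in the reproduction script: `completion_logprobs(client, p, result.generated_text, [t.id for t in result.details.tokens])`. Checking the token ids guards against the case where the concatenated text tokenizes differently at the prompt/completion boundary, which would otherwise silently misalign the logprobs.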

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot], Apr 07 '24 01:04