
feat: accept list as prompt and use first string

drbh opened this pull request 4 months ago • 3 comments

This PR allows CompletionRequest.prompt to be sent as either a string or an array of strings. When an array is sent, the first value is used if it is a string; otherwise the corresponding error is thrown.

Fixes: https://github.com/huggingface/text-generation-inference/issues/1690 Similar to: https://github.com/vllm-project/vllm/pull/323/files
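
For illustration, a minimal sketch of the two request bodies this change accepts (the model name and token limit are placeholders taken from the examples further down, not fixed values):

# prompt as a single string
single_prompt = {"model": "tgi", "prompt": "Say this is a test", "max_tokens": 2}

# prompt as an array of strings
batch_prompt = {"model": "tgi", "prompt": ["Say", "this", "is", "a", "test"], "max_tokens": 2}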

drbh avatar Apr 03 '24 20:04 drbh

It would be pretty easy to support arrays like we do in TEI: just push all requests onto the internal queue and wait. But I feel that the client would time out very often waiting on the slowest request from the batch, and that could lead to a lot of wasted compute.

OlivierDehaene avatar Apr 08 '24 16:04 OlivierDehaene


notes

  • update to handle multiple requests instead of just the first
  • stream responses back with an index
  • docs https://platform.openai.com/docs/api-reference/completions/create

drbh avatar Apr 09 '24 16:04 drbh

example requests:

streaming with openai

from openai import OpenAI

YOUR_TOKEN = "YOUR_API_KEY"

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key=YOUR_TOKEN,
)

completion = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt=["Say", "this", "is", "a", "test"],
    echo=True,
    n=1,
    stream=True,
    max_tokens=10,
)

# each streamed chunk carries the index of the prompt it corresponds to
for chunk in completion:
    print(chunk)

# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text=' =')], created=1712722135, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text=' ')], created=1712722135, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text='1')], created=1712722136, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text='0')], created=1712722136, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# ...

with aiohttp (streaming)

from aiohttp import ClientSession
import json
import asyncio

base_url = "http://localhost:3000"


request = {
    "model": "tgi",
    "prompt": [
        "What color is the sky?",
        "Is water wet?",
        "What is the capital of France?",
        "def mai",
    ],
    "max_tokens": 10,
    "seed": 0,
    "stream": True,
}

url = f"{base_url}/v1/completions"


async def main():

    async with ClientSession() as session:
        async with session.post(url, json=request) as response:
            async for chunk in response.content.iter_any():
                # the stream is sent as "data:<json>" events separated by blank lines
                chunk = chunk.decode().split("\n\n")
                chunk = [c.replace("data:", "") for c in chunk]
                chunk = [c for c in chunk if c]
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    print(c)

asyncio.run(main())

# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 1, 'text': ' a', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 2, 'text': ' Paris', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 3, 'text': 'nic', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 0, 'text': ' blue', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 1, 'text': ' liquid', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
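
Because chunks for different prompts are interleaved in the stream, a client can reassemble each completion by bucketing chunks on their choice index. A minimal sketch, reusing url, request, and the imports from the example above (collect_by_index is a hypothetical helper, not part of any client library):

from collections import defaultdict

async def collect_by_index():
    texts = defaultdict(str)
    async with ClientSession() as session:
        async with session.post(url, json=request) as response:
            async for chunk in response.content.iter_any():
                # mirrors the simple parsing above; a robust client would also
                # buffer partial lines split across network reads
                for line in chunk.decode().split("\n\n"):
                    line = line.replace("data:", "").strip()
                    if not line or line == "[DONE]":  # some servers end the stream with [DONE]
                        continue
                    payload = json.loads(line)
                    for choice in payload["choices"]:
                        texts[choice["index"]] += choice["text"]
    return dict(texts)

# print(asyncio.run(collect_by_index()))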

sync with requests (non-streaming)

import requests

base_url = "http://localhost:3000"

response = requests.post(
    f"{base_url}/v1/completions",
    json={
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 2,
        "seed": 0,
    },
    stream=False,
)
response = response.json()

print(response)
# {'id': '', 'object': 'text_completion', 'created': 1712722405, 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native', 'choices': [{'index': 0, 'text': " you'", 'logprobs': None, 'finish_reason': 'length'}, {'index': 1, 'text': ' the sequence', 'logprobs': None, 'finish_reason': 'length'}, {'index': 2, 'text': '_cases', 'logprobs': None, 'finish_reason': 'length'}, {'index': 3, 'text': '.\n\n', 'logprobs': None, 'finish_reason': 'length'}, {'index': 4, 'text': '. ', 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 10, 'completion_tokens': 10, 'total_tokens': 20}}
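
To read back one completion per prompt from the non-streaming response, the choices can be ordered by index; a minimal sketch, reusing response from above:

for choice in sorted(response["choices"], key=lambda c: c["index"]):
    print(choice["index"], choice["text"])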

drbh avatar Apr 10 '24 04:04 drbh

Note: the client library intentionally does not include a completions method because this is a legacy API. The changes in this PR are to align with the API and to address integrations with existing tools (LangChain retrieval chain).

drbh avatar Apr 10 '24 15:04 drbh

Should be good after rebase.

Narsil avatar Apr 16 '24 14:04 Narsil

Note: the failing client tests do not seem related to these changes and are resolved here: https://github.com/huggingface/text-generation-inference/pull/1751

drbh avatar Apr 16 '24 17:04 drbh

... The logs are rather poor compared to the regular endpoints.

2024-04-16T10:42:49.931556Z  INFO text_generation_router::server: router/src/server.rs:500: Success

vs

2024-04-16T10:42:56.302342Z  INFO generate_stream{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(10), return_full_text: None, stop: [], truncate: None, watermark: false, details: false, decoder_input_details: false, seed: None, top_n_tokens: None, grammar: None } total_time="429.831681ms" validation_time="217.73µs" queue_time="64.823µs" inference_time="429.549248ms" time_per_token="42.954924ms" seed="None"}: text_generation_router::server: router/src/server.rs:500: Success

Yeah, it's a bit strange that the same logging line produces more output in one case. Any ideas on how to have it emit the same output?

drbh avatar Apr 16 '24 17:04 drbh

Should be about the span capture

Narsil avatar Apr 16 '24 20:04 Narsil

Logs are now bubbled up to the calling function and output the same information as generate and generate_stream.

Change: generate_internal and generate_stream_internal now take a span as an argument, which is passed to tracing::info as the parent span.

drbh avatar Apr 17 '24 02:04 drbh