text-generation-inference
feat: accept list as prompt and use first string
This PR allows CompletionRequest.prompt
to be sent as either a string or an array of strings. When an array is sent, the first value is used if it is a string; otherwise an appropriate error is returned.
Fixes: https://github.com/huggingface/text-generation-inference/issues/1690 Similar to: https://github.com/vllm-project/vllm/pull/323/files
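For illustration, a minimal sketch of both accepted shapes of prompt against the same endpoint (assuming a local TGI instance on port 3000, as in the examples below):

import requests

base_url = "http://localhost:3000"

# prompt as a plain string (already supported)
requests.post(
    f"{base_url}/v1/completions",
    json={"model": "tgi", "prompt": "Say this is a test", "max_tokens": 2},
)

# prompt as an array of strings (the case added by this PR)
requests.post(
    f"{base_url}/v1/completions",
    json={"model": "tgi", "prompt": ["Say", "this", "is", "a", "test"], "max_tokens": 2},
)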
It would be pretty easy to support arrays like we do in TEI: just push all requests into the internal queue and wait. But I feel that the client would time out very often waiting on the slowest request from the batch, and that could lead to a lot of wasted compute.
notes
- update to handle multiple requests instead of just the first
- stream responses back with an index
- docs https://platform.openai.com/docs/api-reference/completions/create
example requests:
streaming with openai
from openai import OpenAI

YOUR_TOKEN = "YOUR_API_KEY"

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key=YOUR_TOKEN,
)

completion = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt=["Say", "this", "is", "a", "test"],
    echo=True,
    n=1,
    stream=True,
    max_tokens=10,
)

for chunk in completion:
    print(chunk)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text=' =')], created=1712722135, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text=' ')], created=1712722135, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text='1')], created=1712722136, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text='0')], created=1712722136, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# ...
with aiohttp (streaming)
from aiohttp import ClientSession
import json
import asyncio

base_url = "http://localhost:3000"

request = {
    "model": "tgi",
    "prompt": [
        "What color is the sky?",
        "Is water wet?",
        "What is the capital of France?",
        "def mai",
    ],
    "max_tokens": 10,
    "seed": 0,
    "stream": True,
}

url = f"{base_url}/v1/completions"


async def main():
    async with ClientSession() as session:
        async with session.post(url, json=request) as response:
            async for chunk in response.content.iter_any():
                chunk = chunk.decode().split("\n\n")
                chunk = [c.replace("data:", "") for c in chunk]
                chunk = [c for c in chunk if c]
                chunk = [json.loads(c) for c in chunk]
                for c in chunk:
                    print(c)


asyncio.run(main())
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 1, 'text': ' a', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 2, 'text': ' Paris', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 3, 'text': 'nic', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 0, 'text': ' blue', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 1, 'text': ' liquid', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
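Because chunks from the different prompts arrive interleaved, each one carries its prompt's position in choices[].index, so a client can regroup the stream per prompt. A minimal sketch (assuming the parsed chunk dicts from the loop above are collected into a list; collect_by_index is a hypothetical helper, not part of TGI or this PR):

from collections import defaultdict

def collect_by_index(parsed_chunks):
    """Accumulate streamed completion text keyed by choices[].index (one entry per prompt)."""
    texts = defaultdict(str)
    for chunk in parsed_chunks:
        for choice in chunk["choices"]:
            texts[choice["index"]] += choice["text"]
    return dict(texts)

# With the chunks shown above, texts[0] would accumulate the answer to the first prompt
# ("What color is the sky?" -> " blue..."), texts[2] the capital question, and so on.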
sync with requests (non-streaming)
import requests

base_url = "http://localhost:3000"

response = requests.post(
    f"{base_url}/v1/completions",
    json={
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 2,
        "seed": 0,
    },
    stream=False,
)
response = response.json()
print(response)
# {'id': '', 'object': 'text_completion', 'created': 1712722405, 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native', 'choices': [{'index': 0, 'text': " you'", 'logprobs': None, 'finish_reason': 'length'}, {'index': 1, 'text': ' the sequence', 'logprobs': None, 'finish_reason': 'length'}, {'index': 2, 'text': '_cases', 'logprobs': None, 'finish_reason': 'length'}, {'index': 3, 'text': '.\n\n', 'logprobs': None, 'finish_reason': 'length'}, {'index': 4, 'text': '. ', 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 10, 'completion_tokens': 10, 'total_tokens': 20}}
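In the non-streaming case every prompt's completion comes back in a single response, so prompts can be paired with their completions by index. A minimal sketch (assuming the response dict printed above):

prompts = ["Say", "this", "is", "a", "test"]
by_index = {c["index"]: c["text"] for c in response["choices"]}
for i, prompt in enumerate(prompts):
    print(f"{prompt!r} -> {by_index[i]!r}")
# e.g. 'Say' -> " you'" for index 0, matching the output above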
Note: the client library intentionally does not include a completions
method because this is a legacy API. The changes in this PR are to align with the API and to address integrations with existing tools (LangChain retrieval chain).
Should be good after rebase.
Failing client tests do not seem related to these changes and are resolved here: https://github.com/huggingface/text-generation-inference/pull/1751
... The logs are rather poor compared to the regular endpoints.
2024-04-16T10:42:49.931556Z INFO text_generation_router::server: router/src/server.rs:500: Success
vs
2024-04-16T10:42:56.302342Z INFO generate_stream{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(10), return_full_text: None, stop: [], truncate: None, watermark: false, details: false, decoder_input_details: false, seed: None, top_n_tokens: None, grammar: None } total_time="429.831681ms" validation_time="217.73µs" queue_time="64.823µs" inference_time="429.549248ms" time_per_token="42.954924ms" seed="None"}: text_generation_router::server: router/src/server.rs:500: Success
Yeah, it's a bit strange that the same logging line produces more output in one case. Any ideas on how to have it emit the same output?
Should be about the span capture
Logs are now bubbled up to the calling function and output the same information as generate and generate_stream.
Change: generate_internal and generate_stream_internal now take a span as an argument, which is passed to tracing::info as the parent span.