text-generation-inference
Adding support for CTranslate2 acceleration
Feature request
The feature would be to support accelerated inference with the CTranslate2 framework https://github.com/OpenNMT/CTranslate2
Motivation
Reasons to use CTranslate2
Faster float16 generation
In my case, CTranslate2 outperforms vLLM (#478) by a factor of 180% and Transformers by 190%. This was measured for gpt_bigcode_starcoder with an input length of 10, batch_size of 2, and max_new_tokens=32.
Even faster int8 quantization
- int8 quantization in most cases outperforms float16 inference; the quality loss is minimal and the output stays close to float16 (see the conversion and streaming sketch after this list)
Wide range of model support
This could be a new "best effort" option; the supported Transformers architectures are listed at https://opennmt.net/CTranslate2/guides/transformers.html
Stability
Compared to other frameworks, CTranslate2 ships cross-platform unit tests. It is the backend of LibreTranslate: https://github.com/LibreTranslate/LibreTranslate
Streaming Support
https://opennmt.net/CTranslate2/generation.html?highlight=stream
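Below is a minimal sketch (not TGI integration code) of what the int8 conversion and token streaming look like with CTranslate2 directly; the output directory name is a placeholder, and `Generator.generate_tokens` assumes a recent ctranslate2 release that exposes it:

```python
# Minimal CTranslate2 sketch. Assumptions: local output directory name,
# and a ctranslate2 version that provides Generator.generate_tokens.
import ctranslate2
import transformers
from ctranslate2.converters import TransformersConverter

model_id = "bigcode/gpt_bigcode-santacoder"
output_dir = "santacoder-ct2"  # placeholder path for the converted weights

# One-time offline conversion with int8/float16 weight quantization.
# CLI equivalent: ct2-transformers-converter --model <model_id> \
#   --output_dir <output_dir> --quantization int8_float16
TransformersConverter(model_id).convert(
    output_dir, quantization="int8_float16", force=True
)

# Load the converted model and stream tokens as they are produced.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
generator = ctranslate2.Generator(
    output_dir, device="cuda", compute_type="int8_float16"
)

prompt = tokenizer.convert_ids_to_tokens(tokenizer.encode("def hello_world"))
for step in generator.generate_tokens(prompt, max_length=32):
    # each step carries a single generated token
    print(tokenizer.convert_tokens_to_string([step.token]), end="", flush=True)
print()
```

Quantization happens once at conversion time, so serving only pays the smaller int8 memory cost.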
Your contribution
Here is an example script comparing the three frameworks with ray[serve] (as a replacement for the dynamic batching in TGI).
With the code below, my timings are:
Input: "def hello_world" -> 32 new tokens
Median time over 10 runs:
- Transformers : 287.6ms
- VLLM : 278.8ms
- CTranslate2 : 193.5ms
```python
from fastapi import FastAPI
from ray import serve, get_gpu_ids
import time
from typing import List, Dict
from pydantic import BaseModel
from enum import Enum

app = FastAPI(docs_url="/docs")  # let's not do any authentication for now


class HFAPIRequest(BaseModel):
    inputs: str
    parameters: dict = {}


class Framework(Enum):
    ctranslate2 = 1
    transformers = 2
    vllm = 3


@serve.deployment
@serve.ingress(app)
class FastAPIWrapper:
    def __init__(self, model_handle) -> None:
        self.model_handle = model_handle

    @app.get("/v1/heartbeat")
    async def heartbeat(self) -> float:
        """check that the application is running"""
        return time.time()

    @app.post("/v1/complete")
    async def complete(self, data: HFAPIRequest) -> Dict[str, str]:
        """send json-data per http request. returns a completion"""
        try:
            d_dict = data.dict()
            # send the call, get ray_ref as receipt
            ray_ref = await self.model_handle.predict.remote(inputs=d_dict["inputs"])
            # wait until ray_ref finished
            completion = await ray_ref
            return {"generated_text": completion}
        except Exception as ex:
            return {"error": str(ex)}


@serve.deployment(ray_actor_options={"num_gpus": 1})
class CodeCompletionModel:
    def __init__(
        self,
        model_name="michaelfeil/ct2fast-gpt_bigcode-santacoder",
        framework: Framework = Framework.ctranslate2,
    ) -> None:
        self.framework = framework
        gpus = bool(get_gpu_ids())
        if self.framework == Framework.ctranslate2:
            # 1784MiB VRAM, 314.3ms for generation
            from hf_hub_ctranslate2 import GeneratorCT2fromHfHub

            self.model = GeneratorCT2fromHfHub(
                # use a pre-quantized model here
                model_name_or_path=model_name,
                device="cuda" if gpus else "cpu",
                compute_type="int8_float16" if gpus else "int8",
            )
        elif self.framework == Framework.transformers:
            # 2918MiB VRAM, ~1146.8ms for generation
            from transformers import pipeline
            import torch

            self.model = pipeline(
                task="text-generation",
                model="bigcode/gpt_bigcode-santacoder",
                device="cuda:0" if gpus else "cpu",
                torch_dtype=torch.float16,
            )
        elif self.framework == Framework.vllm:
            from vllm import LLM, SamplingParams

            self.sampling_params = SamplingParams(max_tokens=32)
            self.model = LLM(model="bigcode/gpt_bigcode-santacoder")
        else:
            raise ValueError(
                f"{self.framework} is not a supported framework. "
                f"Please use one of the following: {list(Framework)}"
            )

    @serve.batch(max_batch_size=4)
    async def predict(self, inputs: List[str]) -> List[str]:
        if self.framework == Framework.ctranslate2:
            return self.model.generate(
                text=inputs,
                max_length=32,
                include_prompt_in_result=False,
            )
        elif self.framework == Framework.transformers:
            out = self.model(inputs, max_new_tokens=32, return_full_text=True)
            return [o[0]["generated_text"] for o in out]
        elif self.framework == Framework.vllm:
            outputs = self.model.generate(inputs, self.sampling_params)
            return [o.outputs[0].text for o in outputs]


deploy_handle = FastAPIWrapper.bind(
    model_handle=CodeCompletionModel.bind(framework=Framework.ctranslate2)
)

if __name__ == "__main__":
    serve.run(deploy_handle, port=5000)
```
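Once the deployment is running, the /v1/complete route above can be exercised with any HTTP client. A hypothetical local request (host and port depend on how Ray Serve was started) could look like:

```python
import requests

# hypothetical smoke test against the deployment defined above;
# adjust the URL to wherever the Ray Serve HTTP proxy is listening
response = requests.post(
    "http://127.0.0.1:5000/v1/complete",
    json={"inputs": "def hello_world", "parameters": {}},
    timeout=60,
)
print(response.json())  # {"generated_text": "..."} or {"error": "..."}
```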
A batch size of 2 seems very small.
Feel free to try your own example.
```python
if __name__ == "__main__":
    import timeit

    for framework in Framework:
        start = timeit.default_timer()
        model = CodeCompletionModel(framework=framework)
        end = timeit.default_timer()
        model.predict_batch(["warmup"])
        # batch of prompts with an increasing number of tokens per prompt
        inp = [("def " * i) for i in range(1, 16)]
        t = timeit.timeit("model.predict_batch(inp)", globals=locals(), number=10) / 10
        print(
            "framework: ", framework,
            " mean time per batch: ", t,
            "s, model load time was ", end - start, "s",
        )
```
E.g. with a batch size of 64 and increasing prompt lengths:
framework: Framework.vllm mean time per batch: 0.9124439172912389 s, model load time was 6.033201194833964 s
framework: Framework.ctranslate2 mean time per batch: 0.5082830318249763 s, model load time was 7.9577884785830975 s
@michaelfeil Hi, I was working on a CTranslate2 version of WizardCoder and the latency with max_new_tokens=128 is a lot higher (about 4 s) compared to yours. May I ask what machine you are using, or did you modify anything to make it faster?
In my experience, the mean time of your example should grow roughly 4X to ~2 s with max_new_tokens=128, which is still 2X faster than my results.
Yung-Kai The above uses santacoder (1.1B params) with 32 tokens. What's the relative speed difference across all frameworks?
@michaelfeil Thanks for your reply. I got around 16 s on HF and 5 s with 4-bit quantization via AutoGPTQ. I haven't been able to try vLLM due to OOM in my case when I enlarge max_num_seq. I guess starcoder should take longer than santacoder, but I'm not sure how much longer.
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.