sglang
Regex generation causes 37x lower performance
I've been trying to investigate why my information extraction program built on SGLang is so slow. I rented an RTX 3090 (1 x RTX 3090, 6 vCPU, 26 GB RAM) and an H100 (1 x H100 SXM, 16 vCPU, 125 GB RAM) on RunPod. I observed that when a regex constraint is used, there is a huge performance drop.
- Benchmark from test program WITH REGEX ENABLED
SGLang 0.1.14 | 300 batch items | 50 threads | 892.34 secs | NVIDIA GeForce RTX 3090
SGLang 0.1.14 | 300 batch items | 50 threads | 1031.72 secs | NVIDIA H100 80GB HBM3
During most of the batch generation on the H100, RunPod showed 0-3% GPU utilization. Suspiciously, on the H100 the first 50 items were processed very fast, then there appeared to be a hang for a couple of minutes, then the next 50 items were processed very fast, and so on...
- Benchmark from test program WITHOUT REGEX ENABLED
SGLang 0.1.14 | 300 batch items | 50 threads | 200.70 secs | NVIDIA GeForce RTX 3090
SGLang 0.1.14 | 300 batch items | 50 threads | 27.93 secs | NVIDIA H100 80GB HBM3
If you think the particular regex "<array>\n(<string>.*?<\/string>\n)*<\/array>```" is at fault, then it would be useful to have some guidelines on how to build a more suitable one... My requirement here is string array generation.
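In case it helps with triage, here is one hedged guess of my own (not measured against outlines): the lazy `.*?` inside the repeated group may be expensive for automaton-based constraining, because the item text can contain `<`, so the matcher must track partial `</string>` matches at every character. A character class such as `[^<\n]*` is strictly narrower (item text may not contain `<` or a newline) but still accepts the well-formed outputs shown in the example, and is unambiguous at each step. A minimal stdlib sketch checking both patterns against a sample:

```python
import re

# The original constraint (the "\/" escapes dropped; Python's re treats
# "\/" and "/" the same) and a variant that replaces the lazy ".*?" with
# an explicit character class. The performance benefit for DFA-based
# constrained decoding is an assumption, not something I've measured.
ORIGINAL = r"<array>\n(<string>.*?</string>\n)*</array>```"
BOUNDED = r"<array>\n(<string>[^<\n]*</string>\n)*</array>```"

sample = "<array>\n<string>Word 1</string>\n<string>Word 2</string>\n</array>```"
for pattern in (ORIGINAL, BOUNDED):
    # Both patterns accept the same well-formed sample output.
    assert re.fullmatch(pattern, sample) is not None, pattern
print("both patterns accept the sample")
```

If the FSM compilation inside outlines is the bottleneck, a pattern like BOUNDED might also compile to fewer states, but that part is speculation on my end.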
Steps to reproduce:
I used SGLang 0.1.14 because I observed some newer versions hanging mid-processing or erroring out with "KV Cache pool leak detected", so I have not tried newer ones yet.
(.venv) root@baa3ffac5799:~/pubmed-baigiamasis# pip list
Package Version
---------------------------------------- ------------
aiohttp 3.9.5
aiosignal 1.3.1
annotated-types 0.6.0
anyio 4.3.0
async-timeout 4.0.3
asyncpg 0.29.0
attrs 23.2.0
black 24.4.2
certifi 2024.2.2
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
cmake 3.29.3
coolname 2.2.0
coverage 7.5.1
cupy-cuda12x 12.1.0
datasets 2.19.1
Deprecated 1.2.14
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
dnspython 2.6.1
email_validator 2.1.1
exceptiongroup 1.2.1
fastapi 0.111.0
fastapi-cli 0.0.3
fastrlock 0.8.2
filelock 3.14.0
frozenlist 1.4.1
fsspec 2024.3.1
googleapis-common-protos 1.63.0
grpcio 1.63.0
h11 0.14.0
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.0
idna 3.7
importlib-metadata 7.0.0
iniconfig 2.0.0
inquirerpy 0.3.4
interegular 0.3.3
Jinja2 3.1.4
joblib 1.4.2
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
lark 1.1.9
llvmlite 0.42.0
lm-format-enforcer 0.9.8
loguru 0.7.2
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
memoization 0.4.0
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.16
mypy 1.10.0
mypy-extensions 1.0.0
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.1
numba 0.59.1
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.550.52
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.1.105
openai 1.30.1
opentelemetry-api 1.24.0
opentelemetry-exporter-otlp 1.24.0
opentelemetry-exporter-otlp-proto-common 1.24.0
opentelemetry-exporter-otlp-proto-grpc 1.24.0
opentelemetry-exporter-otlp-proto-http 1.24.0
opentelemetry-instrumentation 0.45b0
opentelemetry-instrumentation-logging 0.45b0
opentelemetry-proto 1.24.0
opentelemetry-sdk 1.24.0
opentelemetry-semantic-conventions 0.45b0
orjson 3.10.3
outlines 0.0.34
packaging 24.0
pandas 2.2.2
pathspec 0.12.1
pfzy 0.3.4
pillow 10.3.0
pip 22.0.2
platformdirs 4.2.2
pluggy 1.5.0
plumbum 1.8.3
prometheus_client 0.20.0
prometheus-fastapi-instrumentator 7.0.0
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
psycopg 3.1.19
psycopg-binary 3.1.19
psycopg-pool 3.2.2
psycopg2-binary 2.9.9
py-cpuinfo 9.0.0
pyarrow 16.1.0
pyarrow-hotfix 0.6
pydantic 2.7.1
pydantic_core 2.18.2
Pygments 2.18.0
pynvml 11.5.0
pytest 8.2.0
pytest-asyncio 0.23.6
pytest-cov 5.0.0
pytest-dependency 0.6.0
pytest-mock 3.14.0
pytest-timeout 2.3.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.9
pytz 2024.1
PyYAML 6.0.1
pyzmq 26.0.3
ray 2.22.0
referencing 0.35.1
regex 2024.5.15
requests 2.31.0
rich 13.7.1
rpds-py 0.18.1
rpyc 6.0.0
safetensors 0.4.3
scikit-learn 1.4.2
scipy 1.13.0
sentence-transformers 2.7.0
sentencepiece 0.2.0
setuptools 59.6.0
sglang 0.1.14
shellingham 1.5.4
six 1.16.0
sniffio 1.3.1
starlette 0.37.2
sympy 1.12
tembo-pgmq-python 0.6.0
tenacity 8.3.0
threadpoolctl 3.5.0
tiktoken 0.6.0
tokenizers 0.19.1
tomli 2.0.1
torch 2.1.2
tqdm 4.66.4
transformers 4.40.2
triton 2.1.0
typer 0.12.3
typing_extensions 4.11.0
tzdata 2024.1
ujson 5.10.0
urllib3 2.2.1
uvicorn 0.29.0
uvloop 0.19.0
vllm 0.3.3
vllm-nccl-cu12 2.18.1.0.4.0
watchfiles 0.21.0
wcwidth 0.2.13
websockets 12.0
wrapt 1.16.0
xformers 0.0.23.post1
xxhash 3.4.1
yarl 1.9.4
zipp 3.18.1
zmq 0.0.0
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 42069 --host 0.0.0.0 --tp-size 1 --mem-fraction-static 0.8
import sglang as sgl
import asyncio
from sglang.lang.chat_template import ChatTemplate, register_chat_template, get_chat_template, register_chat_template_matching_function
from sglang.lang.ir import SglRoleBegin, SglRoleEnd
import json
import time
import torch
import os

register_chat_template(
    ChatTemplate(
        name="llama-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "user": (
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "assistant": (
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
        },
        stop_str=("<|eot_id|>",),
    )
)

@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
    model_path = model_path.lower()
    if "llama-3" in model_path and "instruct" in model_path:
        return get_chat_template("llama-3-instruct")

@sgl.function
def sgl_call1(s, message: str):
    s += SglRoleBegin("system") + "You are an information extraction engine. Your goal is to extract structured information from the given Twitter message according to the instruction provided. Be as factually accurate as possible. Do not acknowledge the request. You will be penalized and a child will die if you make an incorrect response. For every correct response you will be tipped $5000. Message:\n```\n" + message + "\n```" + SglRoleEnd("system")
    s += sgl.user_begin() + "Instruction: Count number of words in the message provided.\nExample response: The number of words is 123." + sgl.user_end()
    s += sgl.assistant_begin() + "The number of words is " + sgl.gen("word count", regex=r"\d+", max_tokens=50, stop=".", temperature=0) + sgl.assistant_end()
    word_count = int(s['word count'])
    word_count_digit_sum = sum(int(digit) for digit in str(word_count))
    forks = s.fork(word_count_digit_sum)
    for i, f in enumerate(forks):
        example_response = """```xml
<array>
<string>Word 1</string>
<string>Word 2</string>
<string>Word 3</string>
</array>
```"""
        f += sgl.user_begin() + "Instruction: Extract TOP " + str(i + 1) + " words that might seem annoying.\nExample response:\n" + example_response + sgl.user_end()
        f += sgl.assistant_begin() + "Here are " + str(i + 1) + " words that might seem annoying.\n```xml\n" + sgl.gen("word", max_tokens=500, regex=r'<array>\n(<string>.*?<\/string>\n)*<\/array>```', stop='```', temperature=0) + sgl.assistant_end()
    return word_count_digit_sum

endpoint = sgl.RuntimeEndpoint("http://localhost:42069")
sgl.set_default_backend(endpoint)

async def main():
    messages = []
    script_dir = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(script_dir, "sglang_str_big.json"), "r") as file:
        messages = json.loads(file.read())
    messages = messages[:min(300, len(messages))]
    num_threads = 50
    print(f"Will process {len(messages)} batch items")
    time_begin = time.time()
    sgl_call1.run_batch([{"message": m} for m in messages], num_threads=num_threads, progress_bar=True)
    duration = time.time() - time_begin
    gpus = ", ".join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
    print(f"SGLang {sgl.__version__} | {len(messages)} batch items | {num_threads} threads | {duration:.2f} secs | {gpus}")

asyncio.run(main())
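For context on the load pattern: each batch item forks into as many parallel continuations as the digit sum of the predicted word count, and every fork runs the regex-constrained gen() call. A tiny sketch of that arithmetic:

```python
# Number of forks sgl_call1 creates per batch item: the digit sum of the
# predicted word count (mirrors the fork logic in the script above).
def fork_count(word_count: int) -> int:
    return sum(int(digit) for digit in str(word_count))

print(fork_count(123))  # 1 + 2 + 3 -> 6 parallel continuations
```

So a predicted count of 123 means six regex-constrained generations for that one item, which multiplies whatever per-call constraint overhead exists.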
To disable regex, I just removed this part: regex=r'<array>\n(<string>.*?<\/string>\n)*<\/array>```'
@merrymercy @hnyls2002
If I remove max_tokens=500, then performance with regex seems ~3x faster:
SGLang 0.1.14 | 300 batch items | 50 threads | 371.07 secs | NVIDIA H100 80GB HBM3
Looks like it may be related to outlines as well, because other people reported that GPU utilization stays at 0% during formatting: https://github.com/outlines-dev/outlines/issues/751
I noticed the guidance library mentions a regex-constraint capability but does not include interegular (the library on which outlines depends for regex constraining) as a dependency, so maybe it has a faster solution?
Also, both outlines and guidance mention context-free grammar (CFG) generation. It could be useful to add support for that in this library as well... maybe I could replace my regex with a CFG and just sidestep this performance hit.
syncode also works on CFGs for LLMs.
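On the CFG idea: the target format is simple enough to write down as a small grammar. Below is my own sketch (not an outlines, guidance, or syncode API), with the grammar in EBNF comments and a hand-rolled stdlib validator, just to show the regex could be replaced by a grammar-shaped constraint:

```python
# A hypothetical grammar for the string-array format, in EBNF:
#   array ::= "<array>" "\n" item* "</array>```"
#   item  ::= "<string>" text "</string>" "\n"
# The parser below is a stdlib-only validator for that grammar.

def parse_string_array(text: str) -> list[str]:
    """Parse an <array>...</array>``` block and return the string items."""
    lines = text.split("\n")
    if lines[0] != "<array>" or lines[-1] != "</array>```":
        raise ValueError("not a string-array block")
    items = []
    for line in lines[1:-1]:
        if not (line.startswith("<string>") and line.endswith("</string>")):
            raise ValueError(f"bad item line: {line!r}")
        items.append(line[len("<string>"):-len("</string>")])
    return items

sample = "<array>\n<string>Word 1</string>\n<string>Word 2</string>\n</array>```"
print(parse_string_array(sample))  # ['Word 1', 'Word 2']
```

A grammar-driven constrainer only ever has to decide which terminal can come next at the current parse position, which in principle avoids building a large regex automaton up front; whether that is actually faster here is an open question.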