[Serve] Proxy w/ retry
Use a proxy with retry on our load balancer instead of redirecting clients. This is useful for spot-based serving.
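For context, the idea looks roughly like the sketch below (assuming `httpx` and FastAPI; the route, retry policy, and hard-coded replica list are simplifications for illustration, not the actual SkyServe implementation): instead of redirecting the client to a replica, the load balancer proxies the request itself and falls back to the next ready replica when one is unreachable, e.g. after a spot preemption.

```python
# Illustrative proxy-with-retry sketch (NOT the actual SkyServe code): the LB
# forwards each request to a ready replica and retries the next one on
# connection-level errors.
import fastapi
import httpx

app = fastapi.FastAPI()
client = httpx.AsyncClient(timeout=None)
# In SkyServe this list is refreshed via the controller sync endpoint;
# hard-coded here for illustration.
READY_REPLICA_URLS = ['http://0.0.0.0:7001', 'http://0.0.0.0:7002']


@app.api_route('/{path:path}', methods=['GET', 'POST'])
async def proxy_with_retry(request: fastapi.Request, path: str):
    body = await request.body()
    for url in READY_REPLICA_URLS:
        try:
            resp = await client.request(request.method, f'{url}/{path}',
                                        content=body)
            return fastapi.Response(
                content=resp.content,
                status_code=resp.status_code,
                media_type=resp.headers.get('content-type'))
        except httpx.RequestError:
            continue  # Replica unreachable (e.g. preempted): try the next one.
    raise fastapi.HTTPException(status_code=503,
                                detail='No ready replicas available.')
```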
TODO:
- [x] Test and make sure it works
- [x] Benchmark, compare w/ previous implementation
- [x] Support streaming
- [x] Make it more robust
- [x] Update documentation
Tested (run the relevant ones):
- [x] Code formatting: `bash format.sh`
- [x] Any manual or new tests for this PR (please specify below)
```python
import fastapi, uvicorn, asyncio
import multiprocessing

from sky.serve import load_balancer

REPLICA_URLS = []
PROCESSES = []
CONTROLLER_PORT = 20018
WORD_TO_STREAM = 'Hello world! Nice to meet you!'
TIME_TO_SLEEP = 0.2


def _start_streaming_replica(port):
    app = fastapi.FastAPI()

    @app.get('/')
    async def stream():

        async def generate_words():
            for word in WORD_TO_STREAM.split():
                yield word + "\n"
                await asyncio.sleep(TIME_TO_SLEEP)

        return fastapi.responses.StreamingResponse(generate_words(),
                                                   media_type="text/plain")

    @app.get('/non-stream')
    async def non_stream():
        return {'message': WORD_TO_STREAM}

    @app.get('/error')
    async def error():
        raise fastapi.HTTPException(status_code=500,
                                    detail='Internal Server Error')

    uvicorn.run(app, host='0.0.0.0', port=port)


def _start_streaming_replica_in_process(port):
    global PROCESSES, REPLICA_URLS
    STREAMING_REPLICA_PROCESS = multiprocessing.Process(
        target=_start_streaming_replica, args=(port,))
    STREAMING_REPLICA_PROCESS.start()
    PROCESSES.append(STREAMING_REPLICA_PROCESS)
    REPLICA_URLS.append(f'0.0.0.0:{port}')


def _start_controller():
    app = fastapi.FastAPI()

    @app.post('/controller/load_balancer_sync')
    async def lb_sync(request: fastapi.Request):
        return {'ready_replica_urls': REPLICA_URLS}

    uvicorn.run(app, host='0.0.0.0', port=CONTROLLER_PORT)


def _start_controller_in_process():
    global PROCESSES
    CONTROLLER_PROCESS = multiprocessing.Process(target=_start_controller)
    CONTROLLER_PROCESS.start()
    PROCESSES.append(CONTROLLER_PROCESS)


if __name__ == '__main__':
    try:
        _start_streaming_replica_in_process(7001)
        _start_streaming_replica_in_process(7002)
        _start_streaming_replica_in_process(7003)
        _start_streaming_replica_in_process(7004)
        _start_controller_in_process()
        lb = load_balancer.SkyServeLoadBalancer(
            controller_url=f'http://0.0.0.0:{CONTROLLER_PORT}',
            load_balancer_port=7000)
        lb.run()
    finally:
        for p in PROCESSES:
            p.terminate()
```
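With this script running, you can hit the load balancer on port 7000 directly: `/` exercises the streaming path, `/non-stream` returns a plain JSON response, and `/error` always returns a 500, so it is easy to check by hand that requests are proxied to the replicas on ports 7001-7004.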
- [x] All skyserve smoke tests: `pytest tests/test_smoke.py --serve`
- [ ] Relevant individual smoke tests: `pytest tests/test_smoke.py::test_fill_in_the_name`
- [ ] Backward compatibility tests: `bash tests/backward_comaptibility_tests.sh`
@Michaelvll This is ready for a look now 🫡 I'm still running smoke tests and adding a new streaming test; will report back later.
All skyserve smoke tests passed 🫡
TODO:
- [x] Investigate side effects of max num connections = 1000 (see the sketch after this list)
- [x] Test that the worker stops generating if the client aborts the request
- [x] Run `client.aclose` in the background; https://github.com/skypilot-org/skypilot/pull/2735#discussion_r1594826172
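On the first item: if the LB uses an `httpx` client for proxying (as in the sketches here), the connection cap would be configured via `httpx.Limits`. A minimal sketch, where the keep-alive value is an illustrative assumption:

```python
import httpx

# Cap the proxy's connection pool at 1000 concurrent connections; the
# keep-alive limit below is an illustrative assumption, not the actual value.
limits = httpx.Limits(max_connections=1000, max_keepalive_connections=100)
client = httpx.AsyncClient(limits=limits, timeout=None)
```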
Tested client abort and it works as well. Use the following script to launch the LB & worker, `curl http://0.0.0.0:7000/`, then Ctrl+C the curl command. The `===========WORKER` logging will stop.
UPD: I'll test e2e LLM workloads later.
```python
import fastapi, uvicorn, asyncio
import multiprocessing

from sky.serve import load_balancer

REPLICA_URLS = []
PROCESSES = []
CONTROLLER_PORT = 20018
WORD_TO_STREAM = 'Hello world! Nice to meet you!'
TIME_TO_SLEEP = 0.2


def _start_streaming_replica(port):
    app = fastapi.FastAPI()

    @app.get('/')
    async def stream():

        async def generate_words():
            for word in WORD_TO_STREAM.split() * 1000:
                yield word + "\n"
                print('===========WORKER', word)
                await asyncio.sleep(TIME_TO_SLEEP)

        return fastapi.responses.StreamingResponse(generate_words(),
                                                   media_type="text/plain")

    @app.get('/non-stream')
    async def non_stream():
        return {'message': WORD_TO_STREAM}

    @app.get('/error')
    async def error():
        raise fastapi.HTTPException(status_code=500,
                                    detail='Internal Server Error')

    uvicorn.run(app, host='0.0.0.0', port=port)


def _start_streaming_replica_in_process(port):
    global PROCESSES, REPLICA_URLS
    STREAMING_REPLICA_PROCESS = multiprocessing.Process(
        target=_start_streaming_replica, args=(port,))
    STREAMING_REPLICA_PROCESS.start()
    PROCESSES.append(STREAMING_REPLICA_PROCESS)
    REPLICA_URLS.append(f'http://0.0.0.0:{port}')


def _start_controller():
    app = fastapi.FastAPI()

    @app.post('/controller/load_balancer_sync')
    async def lb_sync(request: fastapi.Request):
        return {'ready_replica_urls': REPLICA_URLS}

    uvicorn.run(app, host='0.0.0.0', port=CONTROLLER_PORT)


def _start_controller_in_process():
    global PROCESSES
    CONTROLLER_PROCESS = multiprocessing.Process(target=_start_controller)
    CONTROLLER_PROCESS.start()
    PROCESSES.append(CONTROLLER_PROCESS)


if __name__ == '__main__':
    try:
        _start_streaming_replica_in_process(7001)
        _start_controller_in_process()
        lb = load_balancer.SkyServeLoadBalancer(
            controller_url=f'http://0.0.0.0:{CONTROLLER_PORT}',
            load_balancer_port=7000)
        lb.run()
    finally:
        for p in PROCESSES:
            p.terminate()
```
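For reference, the pattern below is roughly what makes the abort propagate; it is a sketch assuming `httpx` on the LB side (the route and `REPLICA_URL` are placeholders), not the exact SkyServe implementation. The upstream response is closed in a background task once the client-facing stream finishes or the client disconnects, so the replica's generator stops being consumed.

```python
# Sketch: proxy a streaming response so that a client abort also tears down
# the upstream connection. Illustrative only; REPLICA_URL and the route are
# placeholders, not the actual SkyServe code.
import fastapi
import httpx
from starlette.background import BackgroundTask

app = fastapi.FastAPI()
client = httpx.AsyncClient(timeout=None)
REPLICA_URL = 'http://0.0.0.0:7001'


@app.get('/')
async def proxy_stream():
    upstream = await client.send(
        client.build_request('GET', REPLICA_URL + '/'), stream=True)
    # If the downstream client disconnects, Starlette stops iterating the
    # body and runs the background task, which closes the upstream stream;
    # the replica's generator then stops (the ===========WORKER prints stop).
    return fastapi.responses.StreamingResponse(
        upstream.aiter_raw(),
        media_type='text/plain',
        background=BackgroundTask(upstream.aclose))
```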
Just tested with a modified version of FastChat and the abort works well. I used the OpenAI client here and manually Ctrl+C'd to abort the request.
YAML I used:
```yaml
service:
  readiness_probe: /v1/models
  replicas: 1

resources:
  ports: 8087
  memory: 32+
  accelerators: L4:1
  disk_size: 1024
  disk_tier: best

envs:
  MODEL_SIZE: 7
  HF_TOKEN: <huggingface-token>

setup: |
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.9 -y
    conda activate chatbot
  fi

  # Install dependencies
  git clone https://github.com/cblmemo/fschat-print-streaming.git fschat
  cd fschat
  git switch print-stream
  pip install -e ".[model_worker,webui]"

  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

run: |
  conda activate chatbot

  echo 'Starting controller...'
  python -u -m fastchat.serve.controller --host 0.0.0.0 > ~/controller.log 2>&1 &
  sleep 10

  echo 'Starting model worker...'
  python -u -m fastchat.serve.model_worker --host 0.0.0.0 \
    --model-path meta-llama/Llama-2-${MODEL_SIZE}b-chat-hf \
    --num-gpus $SKYPILOT_NUM_GPUS_PER_NODE 2>&1 \
    | tee model_worker.log &

  echo 'Waiting for model worker to start...'
  while ! `cat model_worker.log | grep -q 'Uvicorn running on'`; do sleep 1; done

  echo 'Starting openai api server...'
  python -u -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8087 | tee ~/openai_api_server.log
```
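For completeness, the client side of this test is just a streaming chat completion against the service endpoint, interrupted mid-stream with Ctrl+C; roughly the following, where the endpoint URL and model name are placeholders:

```python
# Streaming request against the SkyServe endpoint; Ctrl+C mid-stream to test
# abort propagation. The base_url and model name below are placeholders.
import openai

client = openai.OpenAI(base_url='http://<service-endpoint>/v1', api_key='EMPTY')
stream = client.chat.completions.create(
    model='meta-llama/Llama-2-7b-chat-hf',
    messages=[{'role': 'user', 'content': 'Tell me a long story.'}],
    stream=True)
for chunk in stream:
    print(chunk.choices[0].delta.content or '', end='', flush=True)
```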
This is awesome and an important update @cblmemo! The code looks good to me. Once the tests pass, I think it should be good to go.
Thanks! Rerunning all smoke tests now; will merge after all of them pass 🫡
Fixed a bug introduced in #3484. Merging now
@cblmemo could you elaborate the bug a bit for future reference?
Sure. The smoke test `test_skyserve_auto_restart` uses a GCP command to manually kill an instance, so it requires that the replica is running on GCP. However, in #3484 we accidentally removed the `cloud: gcp` in `auto_restart.yaml`. I added it back ; )
@cblmemo @Michaelvll This seems to have broken my service, which exposes the Gradio UI rather than the API itself. Here's what my service config looks like:
```yaml
envs:
  MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
  HF_TOKEN: ...

service:
  replica_policy:
    min_replicas: 1
    max_replicas: 3
    target_qps_per_replica: 5
  # An actual request for readiness probe.
  readiness_probe:
    initial_delay_seconds: 1800
    path: /

resources:
  cloud: aws
  accelerators: A10G:1
  disk_size: 512  # Ensure model checkpoints can fit.
  ports: 8001  # Expose to internet traffic.

setup: |
  conda activate vllm
  if [ $? -ne 0 ]; then
    conda create -n vllm python=3.10 -y
    conda activate vllm
  fi

  pip install vllm==0.4.2
  # Install Gradio for web UI.
  pip install gradio openai
  pip install flash-attn==2.5.9.post1
  pip install numpy==1.26.4

run: |
  conda activate vllm
  echo 'Starting vllm api server...'

  # https://github.com/vllm-project/vllm/issues/3098
  export PATH=$PATH:/sbin

  # NOTE: --gpu-memory-utilization 0.95 needed for 4-GPU nodes.
  python -u -m vllm.entrypoints.openai.api_server \
    --port 8000 \
    --model $MODEL_NAME \
    --trust-remote-code --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
    --gpu-memory-utilization 0.95 \
    --max-num-seqs 64 \
    2>&1 | tee api_server.log &

  while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do
    echo 'Waiting for vllm api server to start...'
    sleep 5
  done

  echo 'Starting gradio server...'
  git clone https://github.com/vllm-project/vllm.git || true
  python vllm/examples/gradio_openai_chatbot_webserver.py \
    -m $MODEL_NAME \
    --port 8001 \
    --host 0.0.0.0 \
    --model-url http://localhost:8000/v1 \
    --stop-token-ids 128009,128001
```
The UI looks all messed up and I just get a bunch of error messages when trying to use it. If I open the URL of the instance directly, it works fine.
Hi @anishchopra! Thanks for the feedback. This is probably due to the proxy sending queries to both replicas while the requests contain some session-related information, which makes them error out; that is why a single-replica deployment works well. Previously, a user would be redirected to one replica and all interactions would go to that replica. For now, we would suggest using SkyServe to host your API endpoint and launching the Gradio server manually, pointing it to the service endpoint. If you have further suggestions or requirements, could you help file an issue for this? Thanks!
@cblmemo I actually only had one replica running in this case. Your suggestion of launching a gradio server separately does work, however I bring up this issue because it points to something not being proxied correctly.
Hmm, could you share the output of `sky -v` and `sky -c`? I actually tried on the latest master and it works well, so maybe it is a version issue. And thanks for pointing it out! I just filed issue #3749 to keep track of this :))