
[Serve] Proxy w/ retry

cblmemo opened this issue 1 year ago • 5 comments

Use a proxy with retry on our load balancer. This is useful for spot-based serving.
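
For illustration, a minimal sketch of the proxy-with-retry idea (not the code in this PR; the replica list, round-robin policy, and retry count are assumptions): on a connection error, e.g. a preempted spot replica, the request is retried against the next ready replica instead of failing outright.

import itertools

import fastapi
import httpx
import uvicorn

# Illustrative assumptions: a static replica list and a simple round-robin cycle.
READY_REPLICA_URLS = ['http://0.0.0.0:7001', 'http://0.0.0.0:7002']
MAX_RETRIES = 3

app = fastapi.FastAPI()
_replica_cycle = itertools.cycle(READY_REPLICA_URLS)

@app.api_route('/{path:path}', methods=['GET', 'POST'])
async def proxy_with_retry(request: fastapi.Request, path: str):
    body = await request.body()
    # Drop headers the proxy should set itself.
    headers = {k: v for k, v in request.headers.items()
               if k.lower() not in ('host', 'content-length')}
    last_exc = None
    async with httpx.AsyncClient() as client:
        for _ in range(MAX_RETRIES):
            url = f'{next(_replica_cycle)}/{path}'
            try:
                resp = await client.request(request.method, url,
                                            content=body, headers=headers)
                return fastapi.Response(content=resp.content,
                                        status_code=resp.status_code,
                                        media_type=resp.headers.get('content-type'))
            except httpx.RequestError as e:
                # Replica unreachable (e.g. spot preemption): retry on the next one.
                last_exc = e
    raise fastapi.HTTPException(status_code=503,
                                detail=f'No ready replica: {last_exc!r}')

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=7000)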

TODO:

  • [x] Test and make sure it works
  • [x] Benchmark, compare w/ previous implementation
  • [x] Support streaming
  • [x] Make it more robust
  • [x] Update documentation

Tested (run the relevant ones):

  • [x] Code formatting: bash format.sh
  • [x] Any manual or new tests for this PR (please specify below)
import fastapi, uvicorn, asyncio
import multiprocessing
from sky.serve import load_balancer

REPLICA_URLS = []
PROCESSES = []
CONTROLLER_PORT = 20018
WORD_TO_STREAM = 'Hello world! Nice to meet you!'
TIME_TO_SLEEP = 0.2

def _start_streaming_replica(port):
    app = fastapi.FastAPI()

    @app.get('/')
    async def stream():
        async def generate_words():
            for word in WORD_TO_STREAM.split():
                yield word + "\n"
                await asyncio.sleep(TIME_TO_SLEEP)
        
        return fastapi.responses.StreamingResponse(generate_words(), media_type="text/plain")

    @app.get('/non-stream')
    async def non_stream():
        return {'message': WORD_TO_STREAM}

    @app.get('/error')
    async def error():
        raise fastapi.HTTPException(status_code=500, detail='Internal Server Error')

    uvicorn.run(app, host='0.0.0.0', port=port)

def _start_streaming_replica_in_process(port):
    global PROCESSES, REPLICA_URLS
    STREAMING_REPLICA_PROCESS = multiprocessing.Process(target=_start_streaming_replica, args=(port,))
    STREAMING_REPLICA_PROCESS.start()
    PROCESSES.append(STREAMING_REPLICA_PROCESS)
    REPLICA_URLS.append(f'0.0.0.0:{port}')

def _start_controller():
    app = fastapi.FastAPI()

    @app.post('/controller/load_balancer_sync')
    async def lb_sync(request: fastapi.Request):
        return {'ready_replica_urls': REPLICA_URLS}

    uvicorn.run(app, host='0.0.0.0', port=CONTROLLER_PORT)

def _start_controller_in_process():
    global PROCESSES
    CONTROLLER_PROCESS = multiprocessing.Process(target=_start_controller)
    CONTROLLER_PROCESS.start()
    PROCESSES.append(CONTROLLER_PROCESS)

if __name__ == '__main__':
    try:
        _start_streaming_replica_in_process(7001)
        _start_streaming_replica_in_process(7002)
        _start_streaming_replica_in_process(7003)
        _start_streaming_replica_in_process(7004)
        _start_controller_in_process()
        lb = load_balancer.SkyServeLoadBalancer(
            controller_url=f'http://0.0.0.0:{CONTROLLER_PORT}',
            load_balancer_port=7000)
        lb.run()
    finally:
        for p in PROCESSES:
            p.terminate()
  • [x] All skyserve smoke tests: pytest tests/test_smoke.py --serve
  • [ ] Relevant individual smoke tests: pytest tests/test_smoke.py::test_fill_in_the_name
  • [ ] Backward compatibility tests: bash tests/backward_compatibility_tests.sh

cblmemo avatar Mar 31 '24 07:03 cblmemo

@Michaelvll This is ready for a look now 🫡 I'm still running smoke tests and adding a new streaming test; I will report back later.

cblmemo avatar May 01 '24 15:05 cblmemo

All skyserve smoke tests passed 🫡

cblmemo avatar May 03 '24 08:05 cblmemo

TODO:

  • [x] Investigate side effects of setting the max number of connections to 1000
  • [x] Test whether the worker stops generating when an abort happens on the client side
  • [x] Run client.aclose in the background; https://github.com/skypilot-org/skypilot/pull/2735#discussion_r1594826172 (see the sketch below)
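
For reference, a minimal sketch of the streaming-proxy pattern discussed above (illustrative only, not the actual sky.serve.load_balancer code; the replica address and route are assumptions): the shared httpx client caps concurrent connections at 1000, mirroring the item above, and the upstream response is closed in a background task once the response ends or the downstream client disconnects, so the replica stops generating.

import fastapi
import httpx
import uvicorn
from starlette.background import BackgroundTask

app = fastapi.FastAPI()
# Cap concurrent upstream connections; 1000 mirrors the setting investigated above.
client = httpx.AsyncClient(limits=httpx.Limits(max_connections=1000))

@app.get('/')
async def stream_proxy():
    # Illustrative replica address; the real load balancer picks one of the
    # ready replica URLs reported by the controller.
    upstream_request = client.build_request('GET', 'http://0.0.0.0:7001/')
    upstream = await client.send(upstream_request, stream=True)

    async def close_upstream():
        # Runs after the response finishes or the downstream client disconnects,
        # releasing the upstream connection so the worker stops generating.
        await upstream.aclose()

    return fastapi.responses.StreamingResponse(
        upstream.aiter_raw(),
        status_code=upstream.status_code,
        media_type=upstream.headers.get('content-type'),
        background=BackgroundTask(close_upstream))

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=7000)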

cblmemo avatar May 09 '24 00:05 cblmemo

Tested client-side aborts and they work as well. Use the following script to launch the LB & worker, curl http://0.0.0.0:7000/, then Ctrl+C the curl command. The logging for ===========WORKER will stop.

UPD: I'll test e2e LLM workloads later.

import fastapi, uvicorn, asyncio
import multiprocessing
from sky.serve import load_balancer

REPLICA_URLS = []
PROCESSES = []
CONTROLLER_PORT = 20018
WORD_TO_STREAM = 'Hello world! Nice to meet you!'
TIME_TO_SLEEP = 0.2

def _start_streaming_replica(port):
    app = fastapi.FastAPI()

    @app.get('/')
    async def stream():
        async def generate_words():
            for word in WORD_TO_STREAM.split()*1000:
                yield word + "\n"
                print('===========WORKER', word)
                await asyncio.sleep(TIME_TO_SLEEP)
        
        return fastapi.responses.StreamingResponse(generate_words(), media_type="text/plain")

    @app.get('/non-stream')
    async def non_stream():
        return {'message': WORD_TO_STREAM}

    @app.get('/error')
    async def error():
        raise fastapi.HTTPException(status_code=500, detail='Internal Server Error')

    uvicorn.run(app, host='0.0.0.0', port=port)

def _start_streaming_replica_in_process(port):
    global PROCESSES, REPLICA_URLS
    STREAMING_REPLICA_PROCESS = multiprocessing.Process(target=_start_streaming_replica, args=(port,))
    STREAMING_REPLICA_PROCESS.start()
    PROCESSES.append(STREAMING_REPLICA_PROCESS)
    REPLICA_URLS.append(f'http://0.0.0.0:{port}')

def _start_controller():
    app = fastapi.FastAPI()

    @app.post('/controller/load_balancer_sync')
    async def lb_sync(request: fastapi.Request):
        return {'ready_replica_urls': REPLICA_URLS}

    uvicorn.run(app, host='0.0.0.0', port=CONTROLLER_PORT)

def _start_controller_in_process():
    global PROCESSES
    CONTROLLER_PROCESS = multiprocessing.Process(target=_start_controller)
    CONTROLLER_PROCESS.start()
    PROCESSES.append(CONTROLLER_PROCESS)

if __name__ == '__main__':
    try:
        _start_streaming_replica_in_process(7001)
        _start_controller_in_process()
        lb = load_balancer.SkyServeLoadBalancer(
            controller_url=f'http://0.0.0.0:{CONTROLLER_PORT}',
            load_balancer_port=7000)
        lb.run()
    finally:
        for p in PROCESSES:
            p.terminate()

cblmemo avatar May 09 '24 08:05 cblmemo

Just tested with a modified version of FastChat and the client-side abort works well. I used the OpenAI client here and manually pressed Ctrl+C to abort the request.
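
For reference, a sketch of the manual abort test described above (the endpoint URL is a placeholder and the model name just matches the YAML below): start a streaming chat completion against the SkyServe endpoint and press Ctrl+C mid-stream; with this PR, the replica stops generating.

import openai

# Placeholder: replace with the endpoint reported by `sky serve status`.
ENDPOINT = 'http://<skyserve-endpoint>'

client = openai.OpenAI(base_url=f'{ENDPOINT}/v1', api_key='EMPTY')
stream = client.chat.completions.create(
    model='meta-llama/Llama-2-7b-chat-hf',
    messages=[{'role': 'user', 'content': 'Tell me a long story.'}],
    stream=True)
# Press Ctrl+C while tokens are streaming to abort the request.
for chunk in stream:
    print(chunk.choices[0].delta.content or '', end='', flush=True)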

The YAML I used:

service:
  readiness_probe: /v1/models
  replicas: 1

resources:
  ports: 8087
  memory: 32+
  accelerators: L4:1
  disk_size: 1024
  disk_tier: best

envs:
  MODEL_SIZE: 7
  HF_TOKEN: <huggingface-token>

setup: |
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.9 -y
    conda activate chatbot
  fi

  # Install dependencies
  git clone https://github.com/cblmemo/fschat-print-streaming.git fschat
  cd fschat
  git switch print-stream
  pip install -e ".[model_worker,webui]"
  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

run: |
  conda activate chatbot
  
  echo 'Starting controller...'
  python -u -m fastchat.serve.controller --host 0.0.0.0 > ~/controller.log 2>&1 &
  sleep 10
  echo 'Starting model worker...'
  python -u -m fastchat.serve.model_worker --host 0.0.0.0 \
            --model-path meta-llama/Llama-2-${MODEL_SIZE}b-chat-hf \
            --num-gpus $SKYPILOT_NUM_GPUS_PER_NODE 2>&1 \
            | tee model_worker.log &

  echo 'Waiting for model worker to start...'
  while ! `cat model_worker.log | grep -q 'Uvicorn running on'`; do sleep 1; done

  echo 'Starting openai api server...'
  python -u -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8087 | tee ~/openai_api_server.log

cblmemo avatar May 09 '24 17:05 cblmemo

This is awesome and an important update @cblmemo! The code looks good to me. Once the tests pass, I think it should be good to go.

Thanks! Rerunning all smoke tests now; will merge after all of them pass 🫡

cblmemo avatar May 14 '24 06:05 cblmemo

Fixed a bug introduced in #3484. Merging now

cblmemo avatar May 14 '24 13:05 cblmemo

Fixed a bug introduced in #3484. Merging now

@cblmemo could you elaborate on the bug a bit for future reference?

Michaelvll avatar May 14 '24 15:05 Michaelvll

Fixed a bug introduced in #3484. Merging now

@cblmemo could you elaborate on the bug a bit for future reference?

Sure. The smoke test test_skyserve_auto_restart uses a GCP command to manually kill an instance, so it requires that the replica is running on GCP. However, in #3484 we accidentally removed the cloud: gcp field from auto_restart.yaml. I added it back ; )

cblmemo avatar May 14 '24 16:05 cblmemo

@cblmemo @Michaelvll This seems to have broken my service, which exposes the Gradio UI rather than the API itself. Here's what my service config looks like:

envs:
  MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
  HF_TOKEN: ...

service:
  replica_policy:
    min_replicas: 1
    max_replicas: 3
    target_qps_per_replica: 5

  # An actual request for readiness probe.
  readiness_probe:
    initial_delay_seconds: 1800
    path: /
 
resources:
  cloud: aws
  accelerators: A10G:1
  disk_size: 512  # Ensure model checkpoints can fit.
  ports: 8001  # Expose to internet traffic.

setup: |
  conda activate vllm
  if [ $? -ne 0 ]; then
    conda create -n vllm python=3.10 -y
    conda activate vllm
  fi

  pip install vllm==0.4.2
  # Install Gradio for web UI.
  pip install gradio openai
  pip install flash-attn==2.5.9.post1
  pip install numpy==1.26.4


run: |
  conda activate vllm
  echo 'Starting vllm api server...'

  # https://github.com/vllm-project/vllm/issues/3098
  export PATH=$PATH:/sbin

  # NOTE: --gpu-memory-utilization 0.95 needed for 4-GPU nodes.
  python -u -m vllm.entrypoints.openai.api_server \
    --port 8000 \
    --model $MODEL_NAME \
    --trust-remote-code --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
    --gpu-memory-utilization 0.95 \
    --max-num-seqs 64 \
    2>&1 | tee api_server.log &

  while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do
    echo 'Waiting for vllm api server to start...'
    sleep 5
  done

  echo 'Starting gradio server...'
  git clone https://github.com/vllm-project/vllm.git || true
  python vllm/examples/gradio_openai_chatbot_webserver.py \
    -m $MODEL_NAME \
    --port 8001 \
    --host 0.0.0.0 \
    --model-url http://localhost:8000/v1 \
    --stop-token-ids 128009,128001

The UI looks all messed up and I just get a bunch of error messages when trying to use it. If I open the URL of the instance directly, it works fine.

anishchopra avatar Jul 04 '24 05:07 anishchopra

@cblmemo @Michaelvll This seems to have broken my service, which exposes the Gradio UI rather than the API itself.

Hi @anishchopra! Thanks for the feedback. This is probably because the proxy tries to send queries to both replicas, but the requests may contain some session-related information that makes them error out; thus a single-replica deployment works well. Previously, a user would be redirected to one replica and all interactions would be made with that redirected replica. Currently, we suggest using SkyServe to host your API endpoint and launching the Gradio server manually, pointing it at the service endpoint. If you have further suggestions or requirements, could you help file an issue for this? Thanks!

cblmemo avatar Jul 07 '24 06:07 cblmemo

@cblmemo I actually only had one replica running in this case. Your suggestion of launching a Gradio server separately does work; however, I bring up this issue because it points to something not being proxied correctly.

anishchopra avatar Jul 10 '24 03:07 anishchopra

@cblmemo I actually only had one replica running in this case. Your suggestion of launching a Gradio server separately does work; however, I bring up this issue because it points to something not being proxied correctly.

Hmm, could you share the output of sky -v and sky -c? I actually tried it on the latest master and it works well, so maybe it is a version issue. And thanks for pointing it out! I just filed issue #3749 to keep track of this :))

cblmemo avatar Jul 13 '24 10:07 cblmemo