
[Core][Bugfix][Perf] Introduce `MQLLMEngine` to avoid `asyncio` OH

Open: alexm-redhat opened this pull request 5 months ago (19 comments)

SUMMARY:

  • We removed almost all of the overhead from the OpenAI server, but still saw a significant slowdown when running AsyncLLMEngine instead of LLMEngine on an H100, even when running "headless" (i.e., with no uvicorn server).
  • NOTE: performance varies with CPU quality. The impact of asyncio on a DGX is much lower.
  • This led us to believe that the asyncio event loop in AsyncLLMEngine was the root cause of the slowdown.
  • This PR replaces AsyncLLMEngine with MPLLMEngine. MPLLMEngine works similarly to AsyncLLMEngine (i.e., it runs a background loop, accepts new requests, and streams outputs back to the clients), but it uses ZeroMQ for message passing instead of pulling from queues and pushing to generators.
  • This PR also bounds the number of sockets used by the RPCClient, which avoids "Too many open files" errors.
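The background-loop design described above can be sketched roughly as follows. This is a hypothetical illustration, not MPLLMEngine's code: `multiprocessing.Pipe` stands in for the ZeroMQ sockets, and `Request`, `Output`, `engine_loop`, and `run_client` are made-up names. What it shows is the shape of the idea: the engine runs as a plain synchronous loop in its own process, so no asyncio event loop sits between engine steps, and requests and outputs cross the process boundary as messages.

```python
# Hypothetical sketch: multiprocessing.Pipe stands in for ZeroMQ sockets;
# Request/Output/engine_loop/run_client are illustrative names, not vLLM APIs.
import multiprocessing as mp
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Request:
    request_id: str
    num_tokens: int  # how many fake tokens the engine should produce


@dataclass
class Output:
    request_id: str
    token: int
    finished: bool


def engine_loop(input_conn, output_conn) -> None:
    """Synchronous engine loop running in its own process (no asyncio)."""
    active: Dict[str, List[int]] = {}  # request_id -> [generated, target]
    while True:
        # Drain pending requests; block only when there is no work to do.
        while input_conn.poll(0 if active else None):
            msg = input_conn.recv()
            if msg is None:  # shutdown sentinel from the client
                return
            active[msg.request_id] = [0, msg.num_tokens]
        # One fake "engine step": one new token for every active request.
        outputs = []
        for rid, state in list(active.items()):
            state[0] += 1
            finished = state[0] >= state[1]
            outputs.append(Output(rid, state[0], finished))
            if finished:
                del active[rid]
        # Send the whole step's outputs back as a single message.
        output_conn.send(outputs)


def run_client(requests: List[Request]) -> Dict[str, List[int]]:
    # Prefer "fork" where available so the sketch runs without a
    # __main__ guard; fall back to the platform default otherwise.
    method = "fork" if "fork" in mp.get_all_start_methods() else None
    ctx = mp.get_context(method)
    req_recv, req_send = ctx.Pipe(duplex=False)  # client -> engine
    out_recv, out_send = ctx.Pipe(duplex=False)  # engine -> client
    engine = ctx.Process(target=engine_loop, args=(req_recv, out_send), daemon=True)
    engine.start()
    for request in requests:
        req_send.send(request)
    streamed: Dict[str, List[int]] = {r.request_id: [] for r in requests}
    remaining = len(requests)
    while remaining:
        for out in out_recv.recv():  # one message = one engine step
            streamed[out.request_id].append(out.token)
            remaining -= out.finished
    req_send.send(None)  # ask the engine process to exit
    engine.join()
    return streamed


if __name__ == "__main__":
    print(run_client([Request("a", 3), Request("b", 5)]))
```

Batching a whole step's outputs into one message, as `engine_loop` does here, keeps the per-token messaging cost down; whether the real engine batches in exactly this way is an assumption of the sketch.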

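One common way to keep a client's socket count bounded is to share a single output channel across all requests and demultiplex by request ID on the client side, instead of opening a socket per request. The sketch below assumes that approach; `DemuxClient`, `add_request`, and the `(request_id, token, finished)` message format are hypothetical, and an `asyncio.Queue` stands in for the one shared socket.

```python
# Hypothetical sketch: DemuxClient and the (request_id, token, finished)
# message format are illustrative; an asyncio.Queue stands in for one socket.
import asyncio
from typing import AsyncIterator, Dict, Tuple

Message = Tuple[str, int, bool]


class DemuxClient:
    """Routes messages from ONE shared channel to per-request queues."""

    def __init__(self, shared_outputs: "asyncio.Queue[Message]") -> None:
        self.shared_outputs = shared_outputs
        self.queues: Dict[str, "asyncio.Queue[Tuple[int, bool]]"] = {}

    async def output_loop(self) -> None:
        # Background task: read the shared channel and dispatch by request ID.
        while True:
            request_id, token, finished = await self.shared_outputs.get()
            queue = self.queues.get(request_id)
            if queue is not None:  # ignore outputs for unknown/aborted requests
                queue.put_nowait((token, finished))

    def add_request(self, request_id: str) -> AsyncIterator[int]:
        # Register the queue before any output can arrive for this request.
        queue: "asyncio.Queue[Tuple[int, bool]]" = asyncio.Queue()
        self.queues[request_id] = queue
        return self._stream(request_id, queue)

    async def _stream(self, request_id, queue) -> AsyncIterator[int]:
        try:
            while True:
                token, finished = await queue.get()
                yield token
                if finished:
                    return
        finally:
            self.queues.pop(request_id, None)  # release per-request state


async def demo() -> dict:
    shared: "asyncio.Queue[Message]" = asyncio.Queue()
    client = DemuxClient(shared)
    router = asyncio.create_task(client.output_loop())
    streams = {rid: client.add_request(rid) for rid in ("a", "b")}
    # Fake engine output: both requests interleaved on the single channel.
    for msg in [("a", 1, False), ("b", 1, False), ("a", 2, True),
                ("b", 2, False), ("b", 3, True)]:
        shared.put_nowait(msg)

    async def collect(stream: AsyncIterator[int]) -> list:
        return [token async for token in stream]

    a, b = await asyncio.gather(collect(streams["a"]), collect(streams["b"]))
    router.cancel()
    try:
        await router
    except asyncio.CancelledError:
        pass
    return {"a": a, "b": b, "open_queues": len(client.queues)}


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

However many requests are in flight, only one shared channel exists; each request costs an in-memory queue rather than a file descriptor, and that queue is released as soon as its stream finishes.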
Summary: Performance vs. Offline

branch   scenario      offline throughput   serving throughput   slowdown
main     multistep     42.9                 33.3                 -22%
pr       multistep     42.9                 40.3                 -6%
main     single-step   34.8                 14.9                 -57%
pr       single-step   34.8                 31.6                 -9%
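The slowdown column is the relative change of serving throughput against the offline baseline. A quick recomputation from the rounded throughputs in the table:

```python
# Throughput numbers copied from the summary table: (offline, serving).
results = {
    ("main", "multistep"): (42.9, 33.3),
    ("pr", "multistep"): (42.9, 40.3),
    ("main", "single-step"): (34.8, 14.9),
    ("pr", "single-step"): (34.8, 31.6),
}


def slowdown_pct(offline: float, serving: float) -> float:
    """Relative change of serving throughput vs. the offline baseline, in %."""
    return (serving - offline) / offline * 100


for (branch, scenario), (offline, serving) in results.items():
    print(f"{branch:5} {scenario:12} {slowdown_pct(offline, serving):+.0f}%")
```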
  • NOTE: multistep performance on main looks "less bad" because main currently streams outputs only once every 8 tokens. Once incremental streaming is enabled, we expect the slowdown on main to be closer to -50%.

  • NOTE: there is still some performance to gain by switching the inner loop to protobufs.
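To see why streaming frequency matters for the multistep numbers, here is the per-request message count, assuming one output message per 8 tokens on main (matching `--num-scheduler-steps 8`) versus one message per token with incremental streaming. The 256-token response length and the `stream_messages` helper are hypothetical.

```python
import math


def stream_messages(output_tokens: int, tokens_per_message: int) -> int:
    """Number of output messages a client receives for one request."""
    return math.ceil(output_tokens / tokens_per_message)


# Hypothetical 256-token response.
print(stream_messages(256, 8))  # one message per 8 scheduler steps
print(stream_messages(256, 1))  # incremental streaming: one per token
```

With incremental streaming, the frontend handles 8x as many messages per request, which is why main's multistep slowdown is expected to grow once that is enabled.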

Multistep Performance

1xH100 PERFORMANCE BASELINE:

MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
python3 benchmarks/benchmark_throughput.py --model $MODEL --dataset benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --num-scheduler-steps 8

1xH100 SERVING PERFORMANCE

  • Client:
MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
python3 benchmarks/benchmark_serving.py --model $MODEL --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json 
  • Server:
MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
vllm serve $MODEL --disable-log-requests --num-scheduler-steps 8 --max-model-len 8192

SUMMARY

branch   setup                                          throughput
main     offline                                        42.9
main     serving (mp)                                   33.3
main     serving (--disable-frontend-multiprocessing)   28.9
pr       serving (mp)                                   40.3
pr       serving (--disable-frontend-multiprocessing)   27.9

Single-Step Performance

1xH100 PERFORMANCE BASELINE:

MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
python3 benchmarks/benchmark_throughput.py --model $MODEL --dataset benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json

1xH100 SERVING PERFORMANCE

  • Client:
MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
python3 benchmarks/benchmark_serving.py --model $MODEL --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json 
  • Server:
MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
vllm serve $MODEL --disable-log-requests --max-model-len 8192

SUMMARY

branch   setup          throughput
main     offline        34.8
main     serving (mp)   14.9
pr       serving (mp)   31.6

TODOS:

  • battle-test under high load (abort | HWM?)
  • improve robustness (startup / teardown)
  • feature set (ray, non-gpu, pipeline parallel, profiler)
  • hardening / unit tests for MPLLMEngine
  • get tests passing

Co-authored by @robertgshaw2-neuralmagic

FIX https://github.com/vllm-project/vllm/issues/7920


alexm-redhat, Sep 04 '24 15:09