Custom all reduce kernels
See this doc for a detailed writeup and experiments: Latency-optimal allreduce and cuda graph optimization.pdf
Latency and memory
Tested with
python benchmarks/benchmark_latency.py \
--model NAME -tp TP --input-len 256 --output-len 256 --batch-size BS --num-iters 5
L = Latency, M = Memory
| Model | GPU | TP | BS | L before (s) | L after (s) | M before (MB) | M after (MB) |
|---|---|---|---|---|---|---|---|
| Llama 70B | A100-80G | 4 | 64 | 16.37 | 14.57 | 76723 | 75717 |
| Llama 33B | A100-80G | 2 | 64 | 15.88 | 14.38 | 75739 | 74741 |
| Llama 13B | A30 (no nvlink) | 2 | 32 | 10.43 | 9.79 | 23519 | 22911 |
| Llama 7B | T4 | 2 | 32 | 17.90 | 17.51 | 14969 | 14475 |
Hypothesis on why memory usage is lower with fast allreduce:
- NCCL's internal buffer is captured in the graph
- NCCL requires inserting more nodes per invocation. For example, NCCL requires a few host nodes to ensure correct operation.
Throughput
| Model | GPU | TP | Throughput before | Throughput after |
|---|---|---|---|---|
| Llama 70B | A100-80G | 4 | 3.68 requests/s, 1761.33 tokens/s | 3.87 requests/s, 1852.65 tokens/s |
Performance and memory note
- NVSwitch-based systems should see a larger performance improvement than PCIe systems. Generally, the faster the link, the larger the improvement.
- Latency improvement is more significant at smaller batch sizes, when allreduce is more latency bound.
- The smaller memory overhead of fast allreduce can lead to higher throughput and alleviate some OOM issues when the GPU memory budget is tight (e.g. serving a 33B model with 4xA30).
Implementation note
Since I originally implemented fast allreduce on top of my own fork, this PR contains some changes compared to what's described in the doc. Note that the performance numbers in the writeup are not valid here because my fork differs significantly from upstream. The main changes are:
- No fusion with the residual connection: it's already fused with layernorm.
- No CUDA graph replay optimizations: @WoosukKwon's cuda graph implementation uses a single graph launch per model (mine needed one per layer), so that's probably not necessary.
There was also extensive effort to make it work with cuda graph automatically (automatic IPC buffer registration). My previous implementation required manually allocating a global buffer and changing model code to write the matmul's output to it.
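For context, a rough sketch of what IPC buffer registration involves at the CUDA level (illustrative helper functions, not the PR's actual code; the exchange of handles between tensor-parallel workers is elided):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Rank-local side: allocate a buffer and export an IPC handle that other
// processes on the same node can open.
cudaIpcMemHandle_t export_buffer(void** buf, size_t nbytes) {
  cudaMalloc(buf, nbytes);
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, *buf);
  return handle;
}

// Peer side: map another rank's buffer into this process's address space.
// Kernels running on this GPU can then read/write the peer buffer directly.
void* import_buffer(cudaIpcMemHandle_t handle) {
  void* peer_ptr = nullptr;
  cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess);
  return peer_ptr;
}

int main() {
  void* my_buf = nullptr;
  cudaIpcMemHandle_t handle = export_buffer(&my_buf, 16 << 20);  // illustrative 16 MiB
  // In the real code, the handle would be exchanged with the other
  // tensor-parallel workers, and each worker would call import_buffer()
  // on every peer's handle (exchange mechanism omitted here).
  printf("exported IPC handle for buffer %p\n", my_buf);
  cudaFree(my_buf);
  return 0;
}
```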
The one-hop and two-hop allreduce kernels work very similarly to NVIDIA TensorRT-LLM's kernels (https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu). However, they were developed independently, before TensorRT-LLM's release.
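For readers unfamiliar with the idea, here is a minimal sketch of a one-shot (one-hop) allreduce kernel in the same spirit (not the PR's actual kernel): it assumes each rank already holds peer-accessible pointers to every rank's input and that a cross-GPU barrier (omitted) guarantees all inputs are ready. The real kernels additionally handle that synchronization, vectorized accesses, and the two-hop reduce-scatter + all-gather variant for larger inputs.

```cuda
#include <cuda_fp16.h>

// One-shot allreduce: every rank reads all ranks' inputs directly over
// NVLink/PCIe P2P and writes the full reduced result itself. This trades
// redundant reads for a single kernel launch and very low latency.
// Launch example: one_shot_all_reduce<4><<<blocks, 512, 0, stream>>>(out, dev_ptrs, n);
template <int kWorldSize>
__global__ void one_shot_all_reduce(half* __restrict__ out,
                                    half* const* __restrict__ inputs,
                                    int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    float acc = 0.f;  // accumulate in fp32 for accuracy
#pragma unroll
    for (int r = 0; r < kWorldSize; ++r) {
      acc += __half2float(inputs[r][i]);
    }
    out[i] = __float2half(acc);
  }
}
```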
Note on some source files
- fast_allreduce.cuh is the implementation without PyTorch dependencies. You can copy this file into Compiler Explorer and inspect the PTX/SASS.
- fast_allreduce_test.cu is a C++ test for performance and accuracy comparison between NCCL and my implementation. It compiles much faster than the torch extension. The code there isn't very neat.
Caveats
Compared to NCCL allreduce, there are some caveats for the fast allreduce path.
1. Only works when cuda graph is enabled.
2. Only works for tensors whose byte size is a multiple of 16.
3. Can only work out-of-place for now.
4. Doesn't work with hybrid parallelism for now (e.g. TP + PP). I don't know if that is planned to be supported in vLLM.

Caveats 1 and 2 should be automatically handled and checked. Caveat 3 should be a non-issue since all usage of tensor_model_parallelism uses its return value.
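As a rough illustration, the gating implied by caveats 1-3 could look like the hypothetical helper below (illustrative only, not the PR's actual API):

```cuda
#include <cstddef>

// Fall back to NCCL unless the custom path's preconditions hold: CUDA graph
// capture enabled, byte size a multiple of 16, and an out-of-place output
// buffer available.
bool should_use_custom_allreduce(std::size_t tensor_bytes,
                                 bool cuda_graph_enabled,
                                 bool out_of_place) {
  return cuda_graph_enabled && (tensor_bytes % 16 == 0) && out_of_place;
}
```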
TODOs
- [x] add configuration option
- [x] more end-to-end performance testing on other GPUs, model sizes and TP configs
- [x] end-to-end correctness test with models
- [x] format code
- [ ] (maybe) nit: bind the C++ class properly with pybind (not using C-style binding). Since we don't want to introduce PyTorch dependencies to the header file, we need an additional layer of wrapper anyway.
It's not quite ready to merge. I'm requesting comments.
cc @WoosukKwon @simon-mo
@hanzhi713 This is awesome! Many thanks for the PR! A quick question: do you happen to know about the custom all-reduce kernels in TRT-LLM? Is this PR related to those kernels?
This is included in the PR description:
> The one-hop and two-hop allreduce kernels work very similarly to NVIDIA TensorRT-LLM's kernels (https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu). However, they were developed independently, before TensorRT-LLM's release.
@WoosukKwon Correctness and functionality wise this PR should be ready. Checked a few models and there are only occasional generation differences (due to numerical differences). See the diff below for reference. Left is without fast allreduce and right is with fast allreduce.
https://www.diffchecker.com/hiJejMpy/
Tested with
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 32
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
# Create an LLM.
llm = LLM(model="TheBloke/Llama-2-70B-fp16", tensor_parallel_size=8, disable_fast_allreduce=True) # or False for fast allreduce
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
A 5% throughput improvement is quite impressive from optimizing all reduce with custom kernels. Well done!
> A 5% throughput improvement is quite impressive from optimizing all reduce with custom kernels. Well done!

Yes, considering this is mainly a latency optimization.
@hanzhi713 have you compared https://github.com/pytorch/pytorch/pull/114001 with your custom reduce ops?
@hanzhi713 BTW I got this error when using 2 L4 GPUs:
(RayWorkerVllm pid=51757) INFO 12-26 04:18:45 fast_allreduce.py:21] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
(RayWorkerVllm pid=51757) Failed: Cuda error /home/gcpuser/workspace/vllm/csrc/fast_allreduce.cuh:368 'peer access is not supported between these two devices'
> @hanzhi713 BTW I got this error when using 2 L4 GPUs:
> (RayWorkerVllm pid=51757) INFO 12-26 04:18:45 fast_allreduce.py:21] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
> (RayWorkerVllm pid=51757) Failed: Cuda error /home/gcpuser/workspace/vllm/csrc/fast_allreduce.cuh:368 'peer access is not supported between these two devices'

I guess I have to check this. While all topologies that I have access to support P2P, some platforms don't.
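A minimal sketch of one possible detection, assuming a plain CUDA runtime query is acceptable (illustrative only, not necessarily what the PR ended up using):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Enable the custom allreduce only if every GPU pair can directly access
// each other's memory; otherwise (e.g. 2x L4 without P2P) fall back to NCCL.
bool all_pairs_p2p_capable(int num_devices) {
  for (int i = 0; i < num_devices; ++i) {
    for (int j = 0; j < num_devices; ++j) {
      if (i == j) continue;
      int can_access = 0;
      cudaDeviceCanAccessPeer(&can_access, i, j);
      if (!can_access) return false;
    }
  }
  return true;
}

int main() {
  int n = 0;
  cudaGetDeviceCount(&n);
  printf("P2P capable across all %d GPUs: %s\n", n,
         all_pairs_p2p_capable(n) ? "yes" : "no");
  return 0;
}
```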
> @hanzhi713 have you compared pytorch/pytorch#114001 with your custom reduce ops?

I took a glance, and I'd say performance would be similar (essentially the same idea, again!). The main difference is that I rely on cuda graph to avoid a separate cudaMemcpyAsync call into the allreduce buffer.
> @hanzhi713 BTW I got this error when using 2 L4 GPUs:
> (RayWorkerVllm pid=51757) INFO 12-26 04:18:45 fast_allreduce.py:21] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
> (RayWorkerVllm pid=51757) Failed: Cuda error /home/gcpuser/workspace/vllm/csrc/fast_allreduce.cuh:368 'peer access is not supported between these two devices'
>
> I guess I have to check this. While all topologies that I have access to support P2P, some platforms don't.
@WoosukKwon Can you try this again? I added a detection so this won't fail fatally. I couldn't test this myself.
@hanzhi713 I still got the same error on 2 L4 GPUs.
(RayWorkerVllm pid=70031) INFO 12-27 08:52:31 fast_allreduce.py:70] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
(RayWorkerVllm pid=70031) Failed: Cuda error /home/gcpuser/sky_workdir/vllm/csrc/fast_allreduce.cuh:368 'peer access is not supported between these two devices'
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff8bd892ecdd93f11a79ac738701000000 Worker ID: 37df96f805dda76d6075808cf58b47652f36484fc2eb0a9dac5875f1 Node ID: 841dcb4f7e8438606c0420de9a5c342767f613cff646250ec2ea32c6 Worker IP address: 10.140.0.4 Worker port: 44313 Worker PID: 70032 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
(RayWorkerVllm pid=70032) /home/gcpuser/sky_workdir/vllm/vllm/model_executor/parallel_utils/fast_allreduce.py:121: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
(RayWorkerVllm pid=70032) data = inp.storage()._share_cuda_()
Short sequences are OK. When running long sequences (10k+) repeatedly, it is slow and sometimes gets stuck.
> Short sequences are OK. When running long sequences (10k+) repeatedly, it is slow and sometimes gets stuck.

Does it work when disable_fast_allreduce is set to True?
> Short sequences are OK. When running long sequences (10k+) repeatedly, it is slow and sometimes gets stuck.

I tried this:
python benchmarks/benchmark_latency.py --model TheBloke/Llama-2-70B-fp16 -tp 4 --input-len 256 --output-len 20480 --batch-size 1 --num-iters 3
and latency improved by about 9% when fast allreduce is enabled. I didn't observe your problem. What are your model and platform? And did disable_fast_allreduce solve your problem?
> @hanzhi713 I still got the same error on 2 L4 GPUs.
> (RayWorkerVllm pid=70031) INFO 12-27 08:52:31 fast_allreduce.py:70] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
> (RayWorkerVllm pid=70031) Failed: Cuda error /home/gcpuser/sky_workdir/vllm/csrc/fast_allreduce.cuh:368 'peer access is not supported between these two devices'

I need your help here writing a detection that works for your case (or for non-P2P-capable platforms in general).
> @hanzhi713 have you compared pytorch/pytorch#114001 with your custom reduce ops?
>
> I took a glance, and I'd say performance would be similar (essentially the same idea, again!). The main difference is that I rely on cuda graph to avoid a separate cudaMemcpyAsync call into the allreduce buffer.

I built and tried this feature on the torch main branch (commit d9c0e37bab9462c18508a594659cd34a66abfe1e). Using the ENABLE_INTRA_NODE_COMM environment variable, it can be faster than this PR. Maybe you can compare with it.
> I built and tried this feature on the torch main branch (commit d9c0e37bab9462c18508a594659cd34a66abfe1e). Using the ENABLE_INTRA_NODE_COMM environment variable, it can be faster than this PR. Maybe you can compare with it.

It's not in the nightly yet (due to breaking tests, according to the PR discussion). I'd rather not bother building torch from scratch.
Very good work! May I ask if it can be merged into the main branch soon?
@WoosukKwon Did you finish your review? Regarding the L4 issue, if you can write a detection that works, I would appreciate that. Otherwise I can try to get my hands on some L4 machines, but that may take a while.
@resorcap Actually, it is in nightly pytorch, just via a different PR: https://github.com/pytorch/pytorch/pull/116125. I built this PR with torch nightly (2.3.0.dev20240103) and xformers nightly, and tested it against pytorch's fast allreduce. Note that pytorch's allreduce only supports bfloat16 (it's unclear why there's no float16 support).
export ENABLE_INTRA_NODE_COMM=1
python benchmarks/benchmark_latency.py --model TheBloke/Llama-2-70B-fp16 -tp 4 --input-len 256 --output-len 256 --batch-size 64 --num-iters 5 --dtype bfloat16
Pytorch fast allreduce: Avg latency: 13.739070991426706 seconds
My fast allreduce: Avg latency: 13.681239241734147 seconds
As expected, performance is similar, except that pytorch's version needs a cudaMemcpyAsync before each allreduce, which makes it slightly slower.
@WoosukKwon Any progress here?
@hanzhi713
I built vllm from your branch and ran the throughput benchmark, but I didn't reproduce the performance improvement you claimed.
May I ask if I need to use a special flag during the build process to enable your optimization?
My device is equipped with eight NVIDIA A100s with NVLink.
python3 ./benchmark_throughput.py --backend vllm --tokenizer /tmp/ramdisk/llama-70b-hf/ --dataset ./test.json --model /tmp/ramdisk/llama-70b-hf/ --tensor-parallel-size 8
vllm-project:main:
Throughput: 6.66 requests/s, 3219.35 tokens/s
hanzhi713:fast_ar_sq:
Throughput: 6.66 requests/s, 3218.49 tokens/s
@AethoceSora I updated my branch with the latest commits from main, and tested with the following command
python benchmarks/benchmark_throughput.py --model TheBloke/Llama-2-70B-fp16 -tp 8 --dataset benchmarks/sharegpt.json
Without fast allreduce:
Throughput: 5.85 requests/s, 2798.38 tokens/s
With fast allreduce:
Throughput: 6.30 requests/s, 3011.72 tokens/s
However, since this PR is mainly a latency optimization, you'll see less performance gain at larger batch sizes such as when using the throughput benchmark. In practical serving scenarios, we're usually not 100% throughput bound (otherwise latency might suffer), and you'll see more gains.
I tested the throughput benchmark on Llama-2-70B using a 10k-sample dataset and saw a performance improvement.
vllm-project:main:
Throughput: 6.66 requests/s, 3219.35 tokens/s
hanzhi713:fast_ar_sq:
Throughput: 7.02 requests/s, 3390.49 tokens/s
Considering that this is a large-scale throughput benchmark, the optimization effect is as expected.
Thank you for your excellent work!
Very nice! I ran MMLU on Mixtral with TP8 on this PR as an end-to-end correctness check, and the results look good:
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7033|± |0.1403|
| - humanities |N/A |none | 5|acc |0.6457|± |0.1572|
| - other |N/A |none | 5|acc |0.7744|± |0.1102|
| - social_sciences|N/A |none | 5|acc |0.8096|± |0.0711|
| - stem |N/A |none | 5|acc |0.6156|± |0.1406|
Also here are a few latency measurements on mixtral, TP 8 in different batch size regimes:
| Batch size | Avg ITL with this PR (ms) | Avg ITL without this PR (ms) |
|---|---|---|
| 1 | 10.36 | 11.59 |
| 2 | 10.87 | 12.10 |
| 4 | 11.33 | 12.59 |
| 8 | 12.08 | 13.33 |
| 16 | 13.45 | 14.63 |
| 32 | 16.69 | 17.60 |
| 64 | 22.51 | 24.33 |
| 128 | 34.99 | 36.91 |
Nice improvements across the board, great job and thanks for the contribution!
@hanzhi713 Apologies for the delay. I had some personal issues for the last couple of weeks. I will review the PR today. BTW, one small concern on my side is the name "fast all reduce." While it indeed brings a noticeable performance boost, I feel putting words like "fast" or "optimized" in the name of a function assumes that the "fast" path is actually faster. That might be true now, but later changes (in the code, compilers, or GPU hardware) might make it false. In that case, your name could be unintentionally misleading. Can we call it "custom all reduce" or something like that?
> Can we call it "custom all reduce" or something like that?

Sure, I will make the change.
@WoosukKwon @Yard1 FYI I've added support for eager mode in the last few commits. This is done by adding a cudaMemcpy into an IPC-registered buffer when running in eager mode.
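Roughly, that eager-mode path can be pictured like the sketch below (illustrative names, not the PR's actual API): since eager-mode inputs were never IPC-registered, they are first staged into a pre-registered scratch buffer; under CUDA graph capture the captured input buffer itself is registered, so the copy is unnecessary.

```cuda
#include <cuda_runtime.h>
#include <cstddef>

// Eager-mode custom allreduce sketch: stage the input into an IPC-registered
// scratch buffer that all peer GPUs can read, then run the custom allreduce
// kernel out of that buffer (barrier and kernel launch omitted here).
void eager_custom_all_reduce(const void* input, void* output, std::size_t nbytes,
                             void* registered_scratch, cudaStream_t stream) {
  // Copy the unregistered input into the buffer visible to all peers.
  cudaMemcpyAsync(registered_scratch, input, nbytes,
                  cudaMemcpyDeviceToDevice, stream);
  // ... cross-rank barrier + allreduce kernel reading every rank's
  // registered_scratch and writing the reduced result to `output` ...
  (void)output;
}
```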