llama.cpp
Overlap CUDA graph building and processing to minimize GPU idle time and improve tokens-per-second performance.
Hi all,
This PR takes the ideas applied to the Vulkan backend (#9118 and #10499) and implements them for CUDA, which improves tokens-per-second performance.
Performance
I tested on two systems using an example query in llama-cli and the phi3-mini-4k-instruct model.
Prompt eval tokens per second improved by 2.5% to 7%.
Context print tokens per second improved by 2.8% to 3.57%.
Note that this PR reduces CPU overhead, and that these numbers were measured on top-end CPUs.
On less powerful consumer CPUs, the performance gain should be larger.
Explanation
Currently, before every forward pass, a CUDA graph is built on the CPU and then executed on the GPU. This introduces a delay: the GPU sits idle while the CPU finishes building the CUDA graph.
Our proposed change splits the CPU workload into smaller pieces, so that once the first graph has been built, the CPU and GPU can work in parallel on different CUDA graphs.
The images below, from Nsight Systems, show the before/after. Top is master, bottom is this changeset.
We measured the time between the start of the forward pass (red/green timeline of the CUDA API) and the start of GPU graph execution (orange), and highlighted it with a red circle (256 µs vs 56 µs). This seems small, but because it occurs before every forward pass (i.e., every token-generation step), it adds up quickly.
Note that the two screenshots use different time scales, so the widths of the items are misleading. Only the measured time and the pattern of red/green CUDA API operations are relevant.
Performance impact of switching between graphs during forward passes
My code mirrors the changes in the Vulkan backend. In our testing, each forward pass is split across dozens of graphs. One could argue that the last few context switches are likely unnecessary and hurt performance.
We investigated this. Switching between graphs is a non-issue for now, at about 2 µs per switch. However, we could discuss strategies to steadily increase the graph size to reduce the number of context switches.
@mtavenrath @agray3