
New optimization from NVIDIA to use CUDA Graphs in llama.cpp

Open agray3 opened this issue 2 months ago • 9 comments

Great work everyone on llama.cpp! I am Alan Gray, a developer technology engineer from NVIDIA, and I have developed an optimization that allows the CUDA kernels associated with the generation of each token to be launched and executed as a single CUDA graph, which is giving around a 10-15% overall speedup on our GPUs for the Llama 2 7B cases I have tested so far. More details are below. Could someone please add me to the project (username agray3) so I can push the branch? It will require a bit more testing and tweaking before it is ready for a PR.

For an introduction to CUDA Graphs, see the blog I wrote a few years ago: https://developer.nvidia.com/blog/cuda-graphs/. In llama.cpp I use the stream capture functionality introduced in the blog, which allows the patch to be very non-intrusive: it is isolated within ggml_backend_cuda_graph_compute in ggml-cuda.cu (apart from a utility function that retrieves a function pointer from ggml-cuda/cpy.cu).
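For reference, here is a minimal sketch of the stream-capture pattern (not the actual llama.cpp patch): kernels launched on a capturing stream are recorded into a graph instead of executing immediately, and the whole sequence is then instantiated and launched as one unit. The launch_all_token_kernels function is a hypothetical stand-in for the per-token kernel sequence, and the cudaGraphInstantiate call assumes the CUDA 12 signature.

```cpp
// Minimal CUDA Graphs stream-capture sketch (illustrative, not the actual patch).
#include <cuda_runtime.h>

// Hypothetical stand-in: enqueues the full per-token kernel sequence on `stream`.
void launch_all_token_kernels(cudaStream_t stream);

void run_token_as_graph(cudaStream_t stream) {
    cudaGraph_t     graph;
    cudaGraphExec_t graph_exec;

    // Everything launched on `stream` between Begin/EndCapture is recorded
    // into `graph` instead of being executed immediately.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    launch_all_token_kernels(stream);
    cudaStreamEndCapture(stream, &graph);

    // Instantiate once (CUDA 12 signature; older toolkits take five arguments),
    // then launch the whole kernel sequence with a single API call.
    cudaGraphInstantiate(&graph_exec, graph, 0);
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
}
```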

For example, for llama-2-7b.Q4_K_M inference on H100-PCIe (with --n-gpu-layers 100 -n 128), performance goes from 143.35 to 163.83 tokens per second (a 14% speedup).

Here are some screenshots from Nsight Systems which show why using CUDA Graphs is of benefit.

Here is the execution of a token using the current llama.cpp: [screenshot: nograph]

Each CUDA kernel is launched and executed separately. The highlighted entries show the launch API call associated with a specific kernel.

Zoomed in: [screenshot: nograph_zoom] The main problem is the gaps between the kernels. (Note in this case these gaps are actually mostly due to GPU-side launch overheads rather than CPU API calls.)

With CUDA Graphs: [screenshot: graph]

The whole token generation is launched by a single CUDA graph. Zoomed in: [screenshot: graph_zoom] The use of CUDA graphs has allowed the kernels to be much more tightly packed.

The execution of the graph itself is around 40% faster with CUDA Graphs. The overall speedup is lower (14%), largely due to overheads associated with creating and launching the graph, but there is scope to reduce these overheads further in the future.
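One common way to reduce those creation overheads (offered here only as an illustration of the direction hinted at above, not as a description of the actual patch) is to keep the instantiated executable graph across tokens and patch it in place with cudaGraphExecUpdate after each re-capture, falling back to a full re-instantiation only when the topology changes. A sketch assuming the CUDA 12 API, with launch_all_token_kernels again a hypothetical stand-in:

```cpp
// Sketch of graph re-use across tokens (assumes CUDA 12; illustrative only).
#include <cuda_runtime.h>

void launch_all_token_kernels(cudaStream_t stream);  // hypothetical per-token kernels

static cudaGraphExec_t g_exec = nullptr;  // executable graph cached across tokens

void launch_token(cudaStream_t stream) {
    cudaGraph_t graph;

    // Re-capture is cheap; instantiation is the expensive step we want to avoid.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    launch_all_token_kernels(stream);
    cudaStreamEndCapture(stream, &graph);

    bool need_instantiate = (g_exec == nullptr);
    if (!need_instantiate) {
        cudaGraphExecUpdateResultInfo info;
        if (cudaGraphExecUpdate(g_exec, graph, &info) != cudaSuccess) {
            // Topology changed: discard the old executable graph and rebuild it.
            (void)cudaGetLastError();  // clear the error from the failed update
            cudaGraphExecDestroy(g_exec);
            need_instantiate = true;
        }
    }
    if (need_instantiate) {
        cudaGraphInstantiate(&g_exec, graph, 0);  // expensive, done only rarely
    }

    cudaGraphDestroy(graph);          // the captured graph is no longer needed
    cudaGraphLaunch(g_exec, stream);  // single launch for the whole token
}
```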

agray3 avatar Apr 19 '24 09:04 agray3

Just sent you a collaborator invite

Edit: on second thought, I revoked the invite for the moment. I just noticed that your GitHub account is very new, so I hope you understand me being cautious. You can open a PR without collaborator access for the time being

ggerganov avatar Apr 19 '24 09:04 ggerganov

Thanks for responding so quickly, Georgi. I fully understand you being cautious. Over the last several years I have been one of the main GPU contributors to the GROMACS open source project, but it’s over on GitLab, see https://gitlab.com/gromacs/gromacs/-/merge_requests?scope=all&state=merged&author_username=alangray3. I’ll also drop you an email from my NVIDIA account attaching the patch when I’m back at my desk; I'm out with my dog at the moment :) (I think I need collaborator access to push a branch and hence create a PR, but please correct me if it is possible without that.)

agray3 avatar Apr 19 '24 10:04 agray3

That's really cool @agray3, welcome here <3 I am excited to see your first PR!

phymbert avatar Apr 19 '24 10:04 phymbert

Hi @agray3, you can fork the project on GitHub and push the branch to your fork. Then you will have the option to open a PR from the changes in your fork.

slaren avatar Apr 19 '24 10:04 slaren

Thanks for the warm welcome! I've now created a PR: https://github.com/ggerganov/llama.cpp/pull/6766 but I've labeled it DRAFT since, as mentioned above, I think this will need some more testing across different models, CUDA versions, etc. before it is merged.

agray3 avatar Apr 19 '24 12:04 agray3

While we're on it, I'm really not an expert, but I looked at the source code for the llama_decode_internal method, and I noticed that the compute graph appears to be rebuilt from scratch on every call.

Seems like the decoder isn't really optimized for incremental token decoding (its most common use case) and graph reuse, as most people would expect. It was kind of hard to believe at first until I found this discussion.

Since you're measuring tokens per second, just know that these extra costs are included for each token decoded. I'm not sure how significantly it impacts it. Just wanted to mention it.

Sorry if this isn't 100% relevant.

rotemdan avatar Apr 19 '24 12:04 rotemdan

> While we're on it, I'm really not an expert, but I looked at the source code for the llama_decode_internal method, and I noticed that the compute graph appears to be rebuilt from scratch on every call.
>
> Seems like the decoder isn't really optimized for incremental token decoding (its most common use case) and graph reuse, as most people would expect. It was kind of hard to believe at first until I found this discussion.
>
> Since you're measuring tokens per second, just know that these extra costs are included for each token decoded. I'm not sure how significantly it impacts it. Just wanted to mention it.
>
> Sorry if this isn't 100% relevant.

Yes, this is indeed relevant, and these re-generation costs are definitely significant. If we can further optimize the code to avoid re-generating the graph for each token, the tokens-per-second performance will improve further.
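To illustrate the idea being discussed (purely hypothetical code, not llama.cpp's actual structures or API), the graph could be built once per batch shape and cached, so that incremental decoding with a batch of one token reuses the same graph on every step after the first:

```cpp
// Hypothetical sketch of caching a built compute graph per batch shape.
// `Graph` and `build_graph` are illustrative stand-ins, not real llama.cpp APIs.
#include <cstdint>
#include <map>
#include <memory>

struct Graph { /* placeholder for a built compute graph */ };

std::unique_ptr<Graph> build_graph(int32_t n_tokens) {
    // Stand-in for the expensive per-call graph construction.
    (void)n_tokens;
    return std::make_unique<Graph>();
}

std::map<int32_t, std::unique_ptr<Graph>> graph_cache;  // keyed on batch size

Graph & get_graph(int32_t n_tokens) {
    auto it = graph_cache.find(n_tokens);
    if (it == graph_cache.end()) {
        // Build only when a new shape is seen; incremental decoding
        // (n_tokens == 1) hits the cache on every step after the first.
        it = graph_cache.emplace(n_tokens, build_graph(n_tokens)).first;
    }
    return *it->second;
}
```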

agray3 avatar Apr 19 '24 12:04 agray3

Regarding ggml graph creation overhead: I think the impact of this will heavily depend on the baseline t/s you can get with a given model. Presumably you're investigating the impact with a 7b q4_k_m model on an H100, but I don't think many people will run models with a configuration like that outside of benchmarking. And for a larger model with a lower baseline t/s, the relative impact of graph creation will be lower.

JohannesGaessler avatar Apr 19 '24 13:04 JohannesGaessler

> Regarding ggml graph creation overhead: I think the impact of this will heavily depend on the baseline t/s you can get with a given model. Presumably you're investigating the impact with a 7b q4_k_m model on an H100, but I don't think many people will run models with a configuration like that outside of benchmarking. And for a larger model with a lower baseline t/s, the relative impact of graph creation will be lower.

Agreed, although with future hardware (including consumer GPUs) we would expect the GPU kernels to continue speeding up more than the CPU parts. I also suspect this may be more of an issue for multi-GPU runs (even for larger models).

agray3 avatar Apr 19 '24 13:04 agray3

Related to the above, I'm studying the CPU code executed between device token generations, and I have discovered that the CPU memset calls in the schedule reset are significant and on the critical path. I've made PR #6933, which allows these to be performed earlier, while the CPU is waiting for the previous token to be generated on the device. On A100-PCIe I am measuring a 3% overall speedup for this change, testing main for Llama 2 7B with batch size 1. @ggerganov @JohannesGaessler @slaren could you possibly sanity-check this (very small) change (currently set as a draft PR)?
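The principle behind that change, sketched with hypothetical helper names (reset_scheduler_state and read_logits_and_sample are stand-ins, not the PR's actual code): because a CUDA graph launch returns immediately, CPU-side bookkeeping for the next token can run while the GPU is still busy, rather than after the synchronization point.

```cpp
// Sketch of overlapping CPU bookkeeping with asynchronous GPU work (illustrative only).
#include <cuda_runtime.h>

void reset_scheduler_state();     // hypothetical stand-in for the memset-heavy reset
void read_logits_and_sample();    // hypothetical stand-in for the CPU sampling step

void decode_loop(cudaGraphExec_t graph_exec, cudaStream_t stream, int n_tokens) {
    for (int i = 0; i < n_tokens; ++i) {
        cudaGraphLaunch(graph_exec, stream);  // asynchronous: returns immediately

        // Do the CPU-side resets for the next token while the GPU is still
        // busy with this one, instead of after the synchronize below.
        reset_scheduler_state();

        cudaStreamSynchronize(stream);        // wait for this token's result
        read_logits_and_sample();
    }
}
```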

agray3 avatar Apr 26 '24 16:04 agray3