
Multi GPU support, CUDA refactor, CUDA scratch buffer

Open JohannesGaessler opened this issue 1 year ago • 15 comments

This PR supersedes https://github.com/ggerganov/llama.cpp/pull/1607; that one suffers from a decline in prompt processing speed. This PR adds more things on top and should end up being at least as fast as master across the board.

Compared to master this PR includes:

  • Multi GPU support for CUDA: there is one main GPU that does the calculations for all tensors except matrix multiplications, which take up most of the runtime. Matrix multiplications are split across GPUs and done in parallel. By default the tensor split across GPUs is proportional to each device's VRAM, but it can be set explicitly with --tensor-split (see the sketch after this list).
  • GGML_BACKEND_CUDA and GGML_BACKEND_CL unified to GGML_BACKEND_GPU. New backend GGML_BACKEND_GPU_SPLIT for tensors split across all GPUs.
  • ggml_cuda_compute_forward: new function in ggml-cuda.cu that serves as the entry point for CUDA code. ggml_task_type and ggml_compute_params are moved to ggml.h to make them usable.
  • ggml_cuda_op: new function that manages data transfers between the host and the devices, as well as iteration over that data, so the actual kernels don't have to do this themselves.
  • CUDA kernels for add, SiLU, RMS norm, cpy, reshape, view, transpose, and RoPE. These kernels are not performance critical but serve as glue between matrix multiplication layers so that less data needs to be copied between host and device.
  • CUDA scratch buffers: memory is allocated for CUDA kernels to store temporary results. Currently a single circular 512 MB buffer is used (vs. 2x 512 MB on the CPU) and this seems to be large enough for 33b.
  • Edit: This PR also has a fix for CUDA buffer memory leaks when destroying and recreating LLaMa models.
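
As a rough illustration of the splitting scheme above, the rows of a weight matrix can be partitioned proportionally to each device's share of the split. The following is a minimal sketch of that idea, not the actual ggml-cuda code; all names in it are hypothetical:

```cpp
// Hypothetical illustration of splitting the rows of a weight matrix across
// GPUs proportionally to a set of split fractions (e.g. the free VRAM per
// device, or the values passed via --tensor-split).
#include <cstdint>
#include <cstdio>
#include <vector>

struct RowRange {
    int     device;
    int64_t row_begin;
    int64_t row_end; // exclusive
};

std::vector<RowRange> split_rows(int64_t n_rows, const std::vector<float> & split) {
    float total = 0.0f;
    for (float s : split) {
        total += s;
    }

    std::vector<RowRange> ranges;
    int64_t row = 0;
    for (size_t i = 0; i < split.size(); ++i) {
        // the last device takes the remainder so that every row is covered
        const int64_t n = (i + 1 == split.size())
            ? n_rows - row
            : (int64_t)(n_rows * (split[i] / total));
        ranges.push_back({ (int) i, row, row + n });
        row += n;
    }
    return ranges;
}

int main() {
    // e.g. two GPUs with a 3:1 split on a 4096-row weight matrix
    for (const RowRange & r : split_rows(4096, { 3.0f, 1.0f })) {
        printf("GPU %d gets rows [%lld, %lld)\n",
               r.device, (long long) r.row_begin, (long long) r.row_end);
    }
    return 0;
}
```

Each device can then run its slice of the matrix multiplication in parallel and only the partial results have to be gathered on the main GPU. On the command line, something like `--tensor-split 3,1` should put roughly three quarters of the split tensors on the first device and one quarter on the second; without the option the split follows the devices' VRAM.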

Currently this PR has the following issues:

  • ~~Performance is bad unless all layers are on the GPU.~~ Fixed.
  • ~~The main device cannot be set by the user.~~ Added a CLI argument.
  • ~~Performance regression for GPUs with small amounts of VRAM.~~
  • ~~The way tensors get assigned their scratch buffers via ggml_cuda_assign_buffers is currently very ugly. @ggerganov I have the following proposal: extend ggml_context by a property default_backend that can be set by the user. Then, when the user creates a new ggml_tensor using that context the backend for that tensor is set to the one previously specified.~~ Done; tell me if the design is a problem.

JohannesGaessler avatar Jun 05 '23 18:06 JohannesGaessler

These are the performance numbers that I currently get when not affected by the known issues. PP = prompt processing, TG = token generation.

| GPU | Test | Model | ms/t master | ms/t PR |
| --- | --- | --- | --- | --- |
| RTX 3090 | PP | 7b | 3.21 | 2.26 |
| RTX 3090 | PP | 13b | 4.64 | 3.27 |
| RTX 3090 | PP | 33b | 9.33 | 6.55 |
| RTX 3090 | TG | 7b | 21.53 | 21.78 |
| RTX 3090 | TG | 13b | 32.85 | 33.02 |
| RTX 3090 | TG | 33b | 68.74 | 67.78 |
| GTX 1070 | PP | 7b | 6.17 | 4.93 |
| GTX 1070 | TG | 7b | 71.40 | 71.19 |

JohannesGaessler avatar Jun 05 '23 19:06 JohannesGaessler

There's an issue when processing long prompts.

JohannesGaessler avatar Jun 05 '23 22:06 JohannesGaessler

Have you considered the new approach @ggerganov proposed with the Metal support? In theory it should be faster, since we try to put everything on the GPU and run from there.

howard0su avatar Jun 06 '23 01:06 howard0su

Originally I was going to make this PR once I could put everything on the GPU. However, my multi GPU changes are causing a performance regression for prompt processing, so I'm making an earlier PR with only the tensors that are already working; with those, prompt processing is faster overall.

JohannesGaessler avatar Jun 06 '23 06:06 JohannesGaessler

There is also the issue that the Metal implementation seems to keep copies of the data both in RAM and VRAM. This is what I did initially, and I decidedly do not want to do that because for many users it will exceed their hardware limits.

JohannesGaessler avatar Jun 06 '23 06:06 JohannesGaessler

What specifically do you mean by "partial offloading"? The use of GPU layers or that only part of a layer is on the GPU?

JohannesGaessler avatar Jun 06 '23 08:06 JohannesGaessler

What specifically do you mean by "partial offloading"?

Part of the tensors being processed on the GPU and part of the tensors on the CPU - i.e. the -ngl N option.

ggerganov avatar Jun 06 '23 08:06 ggerganov

Let me be frank: that feature is the whole reason I'm putting in this much effort in the first place. My goal is to reduce the hardware requirements for running llama.cpp and to enable as many regular people as possible to run 33b/65b. I believe that such an increase in efficiency will greatly increase the demand for models at those sizes and thus more people will e.g. make finetunes. And adding partial GPU acceleration (or multi GPU support) in after the fact when I already have a large number of GPU accelerated tensors will be much more difficult compared to building it in from the beginning.

JohannesGaessler avatar Jun 06 '23 09:06 JohannesGaessler

I did more performance tests. The system with the RTX 3090 has a Ryzen 3700X and 3200 MHz RAM, the one with the GTX 1070 and the GTX 1050 ti has an i5-4570S and 1600 MHz RAM. PP = prompt processing evaluated with ./perplexity, TG = token generation.

| GPU | Test | Model | Batch size | GPU layers | ms/t master | ms/t PR |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | PP | 7b | 512 | 33/33 | 3.21 | 2.26 |
| RTX 3090 | PP | 13b | 512 | 41/41 | 4.64 | 3.27 |
| RTX 3090 | PP | 33b | 512 | 61/61 | 9.33 | 6.55 |
| RTX 3090 | TG | 7b | 512 | 33/33 | 21.53 | 21.78 |
| RTX 3090 | TG | 13b | 512 | 41/41 | 32.85 | 33.02 |
| RTX 3090 | TG | 33b | 512 | 61/61 | 68.74 | 67.78 |
| GTX 1070 | PP | 7b | 512 | 33/33 | 6.17 | 4.93 |
| GTX 1070 | PP | 13b | 512 | 37/34 | 10.56 | 8.99 |
| GTX 1070 | TG | 7b | 512 | 33/33 | 86.44 | 85.53 |
| GTX 1070 | TG | 13b | 512 | 37/34 | 176.51 | 164.51 |
| GTX 1050 ti | PP | 7b | 512 | 25/21 | 19.32 | 17.44 |
| GTX 1050 ti | PP | 13b | 512 | 15/12 | 35.99 | 36.43 |
| GTX 1050 ti | TG | 7b | 512 | 25/21 | 197.16 | 192.98 |
| GTX 1050 ti | TG | 13b | 512 | 15/12 | 395.17 | 386.20 |

Due to the additional VRAM used for the scratch buffer, fewer layers fit on the GPU. Prompt processing is faster, and there is no significant performance regression with a batch size of 512. The batch size can be lowered to reduce VRAM usage at the cost of prompt processing speed. In terms of performance I consider this good enough for a merge.

JohannesGaessler avatar Jun 06 '23 09:06 JohannesGaessler

Let me be frank: that feature is the whole reason I'm putting in this much effort in the first place.

That's fine, as long as you can avoid the extra logic in ggml that I've pointed out in the comment. I just thought it would be much easier for you to implement it without trying to support -ngl.

Will take a look at the proposed alternative a bit later

ggerganov avatar Jun 06 '23 09:06 ggerganov

I think that the "partial offloading" feature is complicating the implementation a lot for little to no benefit

vs

Let me be frank: that feature is the whole reason I'm putting in this much effort in the first place.

For me and, as far as I can see, for many users, this is exactly the killer feature. On my RTX 3060, I can only fit half of a 33b model, but the gain is enormous - for the first time ever, 33b runs at an acceptable speed.

I can't fit a whole 65b model in my RAM alongside the OS, but I can with the RTX 3060. It is very slow, but still a win for experiments with AutoGPT, where I don't need to watch interactively. So the benefit is huge: for many users it makes things possible that would not be possible without it.

So please do not cut it out.

maddes8cht avatar Jun 06 '23 10:06 maddes8cht

I moved the offloading logic to llama.cpp and ggml-cuda.cu. Right now the llama.cpp part looks like this: after a tensor that modifies the data is created, the user code calls the function ggml_cuda_assign_buffers to set the correct backend for the tensor output and the pointers on the GPU scratch buffer. The function also checks whether src0 is a tensor with a GGML_OP that does not modify data, such as GGML_OP_RESHAPE, and calls ggml_cuda_assign_buffers recursively until it reaches a tensor that does modify data (this makes the logic and debugging of ggml_cuda_compute_forward much easier and more similar to the logic of ggml_compute_forward).
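
A simplified sketch of that recursion, using hypothetical minimal types rather than the real ggml structs: the output of a no-op such as a reshape or view shares its parent's data, so the GPU backend (and, in the real code, the scratch-buffer pointer) is propagated up the src0 chain until an op that actually writes new data is reached.

```cpp
#include <cstdio>

enum Op      { OP_NONE, OP_ADD, OP_MUL_MAT, OP_RESHAPE, OP_VIEW, OP_TRANSPOSE, OP_PERMUTE };
enum Backend { BACKEND_CPU, BACKEND_GPU };

struct Tensor {
    Op       op      = OP_NONE;
    Backend  backend = BACKEND_CPU;
    Tensor * src0    = nullptr;
};

static bool op_preserves_data(Op op) {
    return op == OP_RESHAPE || op == OP_VIEW || op == OP_TRANSPOSE || op == OP_PERMUTE;
}

// Mark a tensor as living on the GPU; if its parent is a data-preserving op,
// the parent's output is the same data and must be on the GPU as well.
void assign_gpu_buffers(Tensor * t) {
    if (t == nullptr || t->backend == BACKEND_GPU) {
        return;
    }
    t->backend = BACKEND_GPU;
    // the real function would also assign an offset into the GPU scratch buffer here

    if (t->src0 != nullptr && op_preserves_data(t->src0->op)) {
        assign_gpu_buffers(t->src0);
    }
}

int main() {
    Tensor mul_mat; mul_mat.op = OP_MUL_MAT;
    Tensor reshape; reshape.op = OP_RESHAPE; reshape.src0 = &mul_mat;
    Tensor add;     add.op     = OP_ADD;     add.src0     = &reshape;

    // also pulls the reshape onto the GPU; stops at the mul_mat, which writes its own data
    assign_gpu_buffers(&add);
    printf("add: %d, reshape: %d, mul_mat: %d\n",
           add.backend == BACKEND_GPU, reshape.backend == BACKEND_GPU, mul_mat.backend == BACKEND_GPU);
    return 0;
}
```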

Since this has been a repeated point of design discussion: perhaps guidelines for acceptable ggml.c/ggml.h modifications should be added to the README?

JohannesGaessler avatar Jun 06 '23 10:06 JohannesGaessler

I am currently allocating 1 MB per batch size for the circular VRAM scratch buffer. This seems to be enough even for 65b. It's possible that the size requirement will go up once all tensors are on the GPU but for now it looks like there is potential to reduce the memory footprint of the CPU code as well.
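
To make the idea concrete, here is a hypothetical sketch of a circular scratch buffer along those lines: one fixed allocation (sized at roughly 1 MB per unit of batch size) from which temporary tensor results get offsets, wrapping around to the start once the end is reached. This is illustrative only, not the actual ggml-cuda scratch buffer implementation.

```cpp
#include <cstdint>
#include <cstdio>

struct ScratchBuffer {
    uint8_t * base   = nullptr; // in the real code this would be a cudaMalloc'd VRAM pointer
    size_t    size   = 0;
    size_t    offset = 0;

    void * alloc(size_t n) {
        n = (n + 255) & ~(size_t) 255;  // align to 256 bytes
        if (offset + n > size) {
            offset = 0;                  // wrap around; old temporaries get overwritten
        }
        void * ptr = base + offset;
        offset += n;
        return ptr;
    }
};

int main() {
    const size_t n_batch = 512;
    const size_t n_embd  = 4096; // e.g. 7b hidden size

    ScratchBuffer scratch;
    scratch.size = n_batch * 1024 * 1024;     // ~1 MB per unit of batch size
    scratch.base = new uint8_t[scratch.size]; // stand-in for device memory

    void * a = scratch.alloc(n_batch * n_embd * sizeof(float)); // e.g. an intermediate activation
    void * b = scratch.alloc(n_batch * n_embd * sizeof(float));
    printf("a at offset %zu, b at offset %zu, buffer size %zu bytes\n",
           (size_t)((uint8_t *) a - scratch.base),
           (size_t)((uint8_t *) b - scratch.base), scratch.size);

    delete[] scratch.base;
    return 0;
}
```

Because the buffer is reused circularly, only the temporaries that are still needed between operations occupy VRAM at any one time, which is why a comparatively small buffer is enough even for the larger models.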

JohannesGaessler avatar Jun 06 '23 12:06 JohannesGaessler

This PR is now feature complete and seems to work correctly. If the design is acceptable I will rebase and do final testing.

JohannesGaessler avatar Jun 06 '23 16:06 JohannesGaessler

@JohannesGaessler

This looks good. Let's rebase and merge.

P.S. Haven't tested it - let me know if you prefer that I do some tests as well

ggerganov avatar Jun 06 '23 17:06 ggerganov

I've pushed a rebased version. The new quantization formats seem to be working correctly in combination with multi GPU. The CI will take some time anyway, so I will quickly do some more testing. My own code should be producing correct results, though, and there is nothing that I would expect to cause a performance regression relative to what I pushed before.

JohannesGaessler avatar Jun 06 '23 18:06 JohannesGaessler

As far as I can tell everything is working correctly. Performance is good as long as I don't forget to disable debug options.

JohannesGaessler avatar Jun 06 '23 18:06 JohannesGaessler

I get Error: connect ECONNREFUSED 127.0.1.1:8080 when I try to run the server example but I get the same error on master so I'm assuming that it has nothing to do with what I did in this PR.

JohannesGaessler avatar Jun 06 '23 19:06 JohannesGaessler

Thank you for being patient with me.

JohannesGaessler avatar Jun 06 '23 19:06 JohannesGaessler

Terrific work, @JohannesGaessler. Been looking forward to this feature for weeks.

Thanks for sharing those performance tables, too. Out of curiosity I wanted to sort the data by performance increase to see which combination of model and GPU performed from best to worst.

| GPU | Test | Model | GPU layers | Perf Delta |
| --- | --- | --- | --- | --- |
| RTX 3090 | PP | 33b | 61/61 | 29.80% |
| RTX 3090 | PP | 7b | 33/33 | 29.60% |
| RTX 3090 | PP | 13b | 41/41 | 29.53% |
| GTX 1070 | PP | 7b | 33/33 | 20.10% |
| GTX 1070 | PP | 13b | 37/34 | 14.87% |
| GTX 1050 ti | PP | 7b | 25/21 | 9.73% |
| GTX 1070 | TG | 13b | 37/34 | 6.80% |
| GTX 1050 ti | TG | 13b | 15/12 | 2.27% |
| GTX 1050 ti | TG | 7b | 25/21 | 2.12% |
| RTX 3090 | TG | 33b | 61/61 | 1.40% |
| GTX 1070 | TG | 7b | 33/33 | 1.05% |
| RTX 3090 | TG | 13b | 41/41 | -0.52% |
| RTX 3090 | TG | 7b | 33/33 | -1.16% |
| GTX 1050 ti | PP | 13b | 15/12 | -1.22% |
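
For reference, the "Perf Delta" column is just the relative improvement of the PR over master, (ms/t master - ms/t PR) / ms/t master; a trivial sketch of the calculation (not part of the PR):

```cpp
#include <cstdio>

int main() {
    struct Row { const char * name; double ms_master; double ms_pr; };
    const Row rows[] = {
        { "RTX 3090 PP 33b", 9.33, 6.55 },
        { "GTX 1070 TG 13b", 176.51, 164.51 },
    };
    for (const Row & r : rows) {
        const double delta = (r.ms_master - r.ms_pr) / r.ms_master * 100.0;
        printf("%s: %.2f%%\n", r.name, delta); // prints 29.80% and 6.80%, matching the table
    }
    return 0;
}
```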

curtisgray avatar Jun 07 '23 05:06 curtisgray

Can the multi-gpu support lower compute-capability cards? Specifically 5.2 for M40 24GB cards?

svanschalkwyk avatar Jun 08 '23 02:06 svanschalkwyk

I didn't investigate what the minimum compute capability is. The multi GPU code does work on 4x GTX Titan X though which have a compute capability of 5.2.

JohannesGaessler avatar Jun 08 '23 06:06 JohannesGaessler

Hey, first of all great PR, but it resulted in a regression on my (HIP/AMD) system, so I wonder if that's specific to the HIP bindings or if there's something in this patch that forces this behaviour. As you can see from the trace below, the prompt eval time is something crazy; in fact the GPU sits idle (while reporting 100% utilization) for the whole time, so I reckon the actual prompt evaluation for some reason reverts to the CPU?

llama_print_timings:        load time = 380273.89 ms
llama_print_timings:      sample time =   180.26 ms /   223 runs   (    0.81 ms per token)
llama_print_timings: prompt eval time = 350277.48 ms /    24 tokens (14594.89 ms per token)
llama_print_timings:        eval time = 128163.59 ms /   222 runs   (  577.31 ms per token)
llama_print_timings:       total time = 508693.15 ms

tucnak avatar Jun 08 '23 12:06 tucnak

I don't see why the generation would be done on the CPU. I think the problem is rather that one of the operations that I used has good performance on NVIDIA but very bad performance on AMD. I'll try to catch up with the HIP PR and to offer some assistance.

JohannesGaessler avatar Jun 08 '23 12:06 JohannesGaessler

Thank you for the timely response! I know very little about how GPUs are programmed, but usually when the card is idle (contrary to what's reported by hwmgr) with no voltage/power fluctuations whatsoever, it isn't doing anything, while all the CPU threads are "working". As you can see, once prompt eval is complete, the generation isn't impacted: 577 ms/tok is normal for a 65b model in 3-bit mode, which is what I've been running in this case, but the issue persists for smaller models in 8-bit and 4-bit v3, and across various warp sizes too.

Is there any debugging info that I may provide that would be of potential use to you?

tucnak avatar Jun 08 '23 13:06 tucnak

Much appreciated Johannes. I'll test it on two M40s this weekend and let you know.

svanschalkwyk avatar Jun 09 '23 02:06 svanschalkwyk

How do I use the multi-gpu? I didn't see an example in the README.md

ehartford avatar Jun 16 '23 18:06 ehartford

https://github.com/ggerganov/llama.cpp/tree/master/examples/main#additional-options

JohannesGaessler avatar Jun 16 '23 18:06 JohannesGaessler

thanks!

ehartford avatar Jun 16 '23 19:06 ehartford

Sorry, it seems I made a mistake at some point and didn't catch it during review. This is not intended.

JohannesGaessler avatar Jun 18 '23 12:06 JohannesGaessler