
Fine tune MUL_MAT, new threading (spin+wait/notify), speedup q_f32 BLAS by splitting COMPUTE stage

Open mqy opened this issue 2 years ago • 57 comments

Introduction

MUL_MAT takes most of the compute time (about 95%), so to speed up llama we have to focus on MUL_MAT. BLAS, one of the fastest MUL_MAT solutions on CPU, is typically efficient at computing large matrix multiplications but tends to be very slow when run in parallel across multiple OS threads. Accelerate, the native BLAS implementation on macOS, has exactly this problem. OpenBLAS and BLIS are a bit slower than Accelerate; their authors claim multi-threading support, but I did not test that. So I assume that, for the large matrix sizes in llama, multi-threaded BLAS does not run faster than single-threaded BLAS.

We have three kinds of MUL_MAT to compute:

  1. mul_mat_f32: both src0 and src1 are F32.
  2. mul_mat_f16_f32: src0 is F16 and src1 is F32.
  3. mul_mat_q_f32: src0 is quantized (Q4_0, Q4_1, ...), and src1 is F32.

For every kind of MUL_MAT we have a pure CPU solution, which has an optional INIT stage and a COMPUTE stage. There are also optional solutions: CUDA/CL, which run on the GPU, and BLAS, which runs on the CPU.

  1. mul_mat_f32: has only one stage: COMPUTE.
    • pure CPU: multi-threaded
    • BLAS, CUDA and CL: single-threaded
  2. mul_mat_f16_f32:
    • pure CPU: two stages, INIT single-threaded and COMPUTE multi-threaded.
    • BLAS, CUDA and CL: single-threaded
  3. mul_mat_q_f32: same as mul_mat_f16_f32, but the de-quantization time is significant.

With BLAS, there are three known problems to solve:

  1. Spin-only threading. While spinning has been the simplest and perhaps the fastest solution, the community has long been seeking a practical threading infrastructure that can avoid busy spinning in certain situations.
  2. Single-threaded BLAS. The typical mul_mat time when N/K >= 4096 ranges from several ms to hundreds of ms. Given n_threads > 1, while BLAS runs in the main thread the worker threads have nothing to do and keep spinning, and this spinning overhead is not acceptable. Given M/N/K and n_threads (and even the src0 type), the diversity of matrix dimensions and hardware/software stacks means we cannot be sure which solution is the fastest. At present, the master branch applies this rule: run CUDA/CL/BLAS in a single OS thread when both src0 and src1 are contiguous and M >= 32 && N >= 32 && K >= 32. For the llama model, this rule is almost equivalent to M >= 32 && N >= 4096 && K >= 4096.
  3. For some N/K, de-quantization time may exceed mul_mat time when M < 128. This range covers the token counts of typical daily conversations. So we had better move de-quantization out of the for loops, so that it can run in multiple threads.
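A minimal sketch of the master-branch fallback rule from item 2. The function name and signature are hypothetical: the real check in ggml.c works on ggml tensors, while this version takes the contiguity flags and M/N/K directly.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical stand-in for the master-branch rule: run BLAS (single OS
 * thread) only when both operands are contiguous and all dims are >= 32. */
static bool use_blas_by_default(bool src0_contiguous, bool src1_contiguous,
                                int64_t M, int64_t N, int64_t K) {
    assert(M > 0 && N > 0 && K > 0);
    return src0_contiguous && src1_contiguous &&
           M >= 32 && N >= 32 && K >= 32;
}
```

For llama 7B, N and K are at least 4096, so in practice M alone decides: M = 31 stays on the pure CPU path and M = 32 switches to BLAS.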

Solutions

This PR tries to solve the problems above. They are tightly coupled, so it's hard to solve one without touching the others.

1. A new threading infrastructure that supports spin + wait/notify

Typical usages are:

  • When computing a task stage, if the main thread knows that the stage can only be run by itself and the stage is configured as idle wait, it issues a wait_now command; the workers receive this command almost immediately and go to wait.
  • Workers can be configured with wait_on_task_done: we look ahead a few future task stages to check whether there are any immediate multi-threading needs. If there are none, the workers are told to go waiting after finishing the current task. This optimization saves energy, but it is hard to implement correctly and efficiently; in addition to a mutex, I have to use a spin lock.
  • Also, when computing a task stage, if the main thread knows that the current stage needs more workers, it executes a syscall to wake up all workers. I once implemented a threading framework that could wait on or wake up a given number of workers, but I finally discarded it because I found no evidence that using only part of the workers helps.
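The spin + wait/notify idea can be illustrated with a generic spin-then-block wait. This is a minimal sketch, not the PR's implementation: hybrid_event, hybrid_event_wait, hybrid_event_notify and run_demo are hypothetical names, and only POSIX threads and C11 atomics are used. A waiter first spins for a bounded number of iterations (fast wakeup when work is imminent), then blocks on a condition variable (no CPU burned while idle); the notifier bumps a generation counter under the mutex so that no wakeup is lost.

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>

/* Hypothetical spin-then-block event; NOT the PR's actual threading API. */
struct hybrid_event {
    atomic_int generation;   /* bumped by every notify */
    pthread_mutex_t mutex;
    pthread_cond_t cond;
};

static void hybrid_event_init(struct hybrid_event *ev) {
    atomic_init(&ev->generation, 0);
    pthread_mutex_init(&ev->mutex, NULL);
    pthread_cond_init(&ev->cond, NULL);
}

static void hybrid_event_destroy(struct hybrid_event *ev) {
    pthread_cond_destroy(&ev->cond);
    pthread_mutex_destroy(&ev->mutex);
}

/* Wait until the generation differs from `seen`: spin first, then block. */
static int hybrid_event_wait(struct hybrid_event *ev, int seen, int spin_iters) {
    for (int i = 0; i < spin_iters; i++) {
        int g = atomic_load_explicit(&ev->generation, memory_order_acquire);
        if (g != seen) return g;
    }
    pthread_mutex_lock(&ev->mutex);
    int g;
    while ((g = atomic_load_explicit(&ev->generation, memory_order_acquire)) == seen)
        pthread_cond_wait(&ev->cond, &ev->mutex);
    pthread_mutex_unlock(&ev->mutex);
    return g;
}

/* Bump the generation under the mutex so a blocking waiter cannot miss it. */
static void hybrid_event_notify(struct hybrid_event *ev) {
    pthread_mutex_lock(&ev->mutex);
    atomic_fetch_add_explicit(&ev->generation, 1, memory_order_release);
    pthread_cond_broadcast(&ev->cond);
    pthread_mutex_unlock(&ev->mutex);
}

/* Demo: n_workers threads wait for one notify, then count themselves. */
struct demo_arg { struct hybrid_event *ev; int spin_iters; atomic_int *done; };

static void *demo_worker(void *p) {
    struct demo_arg *a = p;
    hybrid_event_wait(a->ev, 0, a->spin_iters);
    atomic_fetch_add(a->done, 1);
    return NULL;
}

static int run_demo(int n_workers, int spin_iters) {
    enum { MAX_WORKERS = 16 };
    assert(n_workers > 0 && n_workers <= MAX_WORKERS);
    struct hybrid_event ev;
    hybrid_event_init(&ev);
    atomic_int done;
    atomic_init(&done, 0);
    struct demo_arg arg = { &ev, spin_iters, &done };
    pthread_t th[MAX_WORKERS];
    for (int i = 0; i < n_workers; i++)
        pthread_create(&th[i], NULL, demo_worker, &arg);
    hybrid_event_notify(&ev);            /* main thread wakes everyone */
    for (int i = 0; i < n_workers; i++)
        pthread_join(th[i], NULL);
    hybrid_event_destroy(&ev);
    return atomic_load(&done);
}
```

With spin_iters = 0 a worker goes straight to blocking, which corresponds roughly to the idle wait described above; a very large spin_iters approximates spin-only behavior.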

2. A way to configure how to run a task stage

I want to explicitly define which part of the code to run, whether to run it single-threaded or multi-threaded, and whether workers should idle wait. This is not new, but it introduces idle wait and makes the configuration more explicit. With this we can run benchmarks at will; it frees us from the current implicit #if defined(xxx) and allows us to build with all kinds of solutions. I formally defined task profiles for the three kinds of mul_mat. This took quite a lot of code, but it is very important for the whole solution.
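One way to write such explicit per-stage configuration down, as a hedged sketch: the type and field names are hypothetical and the PR's actual task-profile structs may differ. It encodes the three q_f32 profiles described in this PR (pure CPU; BLAS in COMPUTE; multi-threaded INIT followed by single-threaded BLAS with idle-waiting workers). The idle_wait value of profile 1 is an assumption.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical types; the PR's real task-profile structs may differ. */
enum stage_backend { STAGE_NONE = 0, STAGE_CPU, STAGE_BLAS };

struct stage_config {
    enum stage_backend backend; /* STAGE_NONE: the stage is skipped */
    bool parallel;              /* all threads, or the main thread only */
    bool idle_wait;             /* if not parallel: workers block instead of spin */
};

struct task_profile {
    struct stage_config stages[2]; /* [0] = INIT, [1] = COMPUTE */
};

static const struct task_profile q_f32_profiles[3] = {
    /* #0 pure CPU: single-threaded INIT, multi-threaded COMPUTE */
    { { { STAGE_CPU, false, false }, { STAGE_CPU, true, false } } },
    /* #1 BLAS only in COMPUTE, de-quantizing inside its loop (idle_wait assumed off) */
    { { { STAGE_NONE, false, false }, { STAGE_BLAS, false, false } } },
    /* #2 multi-threaded de-quantization in INIT, then single-threaded BLAS */
    { { { STAGE_CPU, true, false }, { STAGE_BLAS, false, true } } },
};

/* Number of threads that do useful work in a stage. */
static int threads_for_stage(const struct stage_config *s, int n_threads) {
    assert(n_threads >= 1);
    if (s->backend == STAGE_NONE)
        return 0;
    return s->parallel ? n_threads : 1;
}
```

Keeping the configuration as data means a bench tool can iterate over every profile of a shape without any #if defined(...) switches.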

3. A flexible tune(bench) tool to generate bench data

This tool has the following features/benefits:

  • Supports all llama models and typical matrix sizes and types (attention layer, feed-forward layer, RoPE)
  • Supports all types (F32, F16, all of the Qx_x). NOTE: F32 and F16 are disabled as a workaround for an unfixed bug.
  • Able to write to/read from a file, so the results can be generated ahead of time and loaded into memory later.
  • The data file is designed to be self-contained: it includes the model, type, backend and all 6 typical shapes; every shape contains its task profiles and, for every task profile, the execution time of each task stage.
  • Able to estimate the execution time for any M and n_threads, with corresponding APIs for GGML.
  • Analyzes bench data for n_threads. The output is CSV blocks, so it can be easily visualized.
  • Should cover a typical M range. I previously generated M from a constant start value with a constant step (for example, from 16 in steps of 8). Now I generate M as M := 2^n for n = 0, 1, 2, ... (e.g. 10 values, from 1 to 512). This balances two fundamental needs: (1) the M range should be reasonably large; (2) more M values should be assigned to M <= 32, because I guess this is the typical conversation token count that will be executed frequently, and this M range is sensitive to profile selection under multi-threading.
  • Should run as fast as possible. On my device, benching 7B/Q4_0/Accelerate with 10 M values from 1 up to 512 takes about 75 seconds for 3 passes with 1 thread; one pass with 1 thread takes about 35 seconds, and 1 pass with 4 threads and a max M of 128 takes about 13 s. The current speed is not good enough for running the bench at program startup.
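The M sampling described above can be sketched as follows; gen_bench_m_values is a hypothetical name, and m_num mirrors the bench parameter of the same name mentioned later in this thread. With m_num = 10, six of the ten values (1 to 32) fall in the M <= 32 range.

```c
#include <assert.h>

/* Fill out[0..m_num) with M = 2^n for n = 0 .. m_num-1 (1, 2, 4, ...).
 * Returns the number of values written. */
static int gen_bench_m_values(int m_num, int *out) {
    assert(m_num > 0 && m_num < 31); /* keep 2^n within int range */
    for (int n = 0; n < m_num; n++)
        out[n] = 1 << n;
    return m_num;
}
```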

4. Adapt llama and ggml to schedule with bench data

After the bench data is loaded into the program, during graph computation we first match the shape by the given N/K, then estimate the time for every profile that this shape supports, and finally select the fastest profile. Since in practice we only bench a limited number of M values (10 or so), we need some way to estimate the time for any M. Because the M-time curve is nearly linear, I use interpolation. This is not very elegant, but it is the best affordable way I can think of. Non-contiguous matrices are not suitable for BLAS, so they are scheduled to the pure CPU profile. If both src0 and src1 are contiguous but no bench data is loaded, or for some unknown reason or bug we cannot find the shape for the given N/K, or we are unable to estimate, we fall back to the traditional logic: M >= 32 && N >= 32 && K >= 32. This is unfortunate, because performance is highly sensitive to estimation bias around 32; you will see this in the following section.
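The estimate-then-select step can be sketched with piecewise-linear interpolation over the benched M values. This is a minimal sketch under assumptions: interp_time, select_fastest_profile and N_POINTS are hypothetical names, every profile is assumed to be benched at the same M values, and M outside the benched range is handled by extrapolating the nearest segment, which may differ from what the PR does.

```c
#include <assert.h>

#define N_POINTS 4 /* number of benched M values; the text uses about 10 */

/* Piecewise-linear estimate of the time at m from benched points
 * (ms strictly increasing). Outside the benched range, the nearest
 * segment is extrapolated. */
static double interp_time(const int ms[N_POINTS], const double ts[N_POINTS], int m) {
    int i = 0;
    while (i < N_POINTS - 2 && m > ms[i + 1])
        i++;
    double slope = (ts[i + 1] - ts[i]) / (double)(ms[i + 1] - ms[i]);
    return ts[i] + slope * (double)(m - ms[i]);
}

/* Pick the profile with the smallest estimated time at m. */
static int select_fastest_profile(const int ms[N_POINTS],
                                  const double ts[][N_POINTS],
                                  int n_profiles, int m) {
    assert(n_profiles > 0);
    int best = 0;
    double best_t = interp_time(ms, ts[0], m);
    for (int p = 1; p < n_profiles; p++) {
        double t = interp_time(ms, ts[p], m);
        if (t < best_t) {
            best_t = t;
            best = p;
        }
    }
    return best;
}
```

Ties keep the lower-numbered profile, so the pure CPU profile (index 0) wins when estimates are equal.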

5. Split single-threaded BLAS

I moved de-quantization out of the combined de-quantization + mul_mat for loops. This let me create a third task profile for q_f32 using BLAS: run de-quantization in the INIT stage with multiple threads, run mul_mat with BLAS in a single thread, and let the workers idle wait.
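The split can be sketched as a per-thread row partition in the INIT stage. The helpers below are hypothetical, and the int8-plus-per-row-scale format is a toy stand-in, not Q4_0. Each thread de-quantizes only its own rows into a shared F32 buffer, so no locking is needed; afterwards the main thread could hand the buffer to a single-threaded BLAS call such as cblas_sgemm while the workers idle wait.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Rows of src0 assigned to thread ith out of nth: [*start, *end).
 * A ceil-divide split, similar to the per-thread row ranges in ggml.c. */
static void thread_row_range(int nrows, int ith, int nth, int *start, int *end) {
    assert(nth > 0 && ith >= 0 && ith < nth);
    int per_thread = (nrows + nth - 1) / nth;
    int s = per_thread * ith;
    int e = s + per_thread;
    *start = s < nrows ? s : nrows;
    *end = e < nrows ? e : nrows;
}

/* Toy de-quantization (hypothetical format, NOT Q4_0): int8 values with
 * one float scale per row. Threads write disjoint rows of `out`. */
static void dequantize_rows(const int8_t *q, const float *scale, float *out,
                            int ncols, int row_start, int row_end) {
    for (int r = row_start; r < row_end; r++) {
        for (int c = 0; c < ncols; c++) {
            size_t idx = (size_t)r * (size_t)ncols + (size_t)c;
            out[idx] = scale[r] * (float)q[idx];
        }
    }
}
```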

Results

Due to the nature of prediction, it's a bit hard for me to bench end to end. I wrote a bench tool named prompt.sh that asks llama questions like this: 0+0=0.1+1=1.2+2=. Although this makes it easy to construct a prompt of almost any approximate size, such questions are likely to take llama too long to think about, resulting in unusual bench times that may be longer than for normal questions. I have to admit that I don't know how to run an end-to-end bench efficiently and correctly. Anyway, I ran examples/chat.sh with 4 threads many times and often observed the prompt time decrease by about 35%, sometimes over 40%, compared to master.

So let me explain in a more rigorous but perhaps easier-to-understand way with a bunch of images. First, let's define several tokens that will be used to identify the task stages of the three q_f32 profiles.

  • #0_0_nth=1 : profile 0, stage 0, n_threads = 1
  • #0_1_nth=1 : profile 0, stage 1, n_threads = 1
  • #0___nth=1 : profile 0, total, n_threads = 1
  • #1_1_nth=1 : profile 1, stage 1, n_threads = 1
  • #1___nth=1 : profile 1, total, n_threads = 1
  • #2_0_nth=1 : profile 2, stage 0, n_threads = 1
  • #2___nth=1 : profile 2, total, n_threads = 1
  • #0___nth=2 : profile 0, total, n_threads = 2
  • ...
  • #2___nth=6 : profile 2, total, n_threads = 6

Here stage 0 is the INIT stage and stage 1 is the COMPUTE stage. These n_threads values are typical because:

  • apart from 1, we usually use an even n_threads.
  • personal computers often do not have many physical cores, so n_threads = 6 is enough.
  • suppose the single-thread time is t1; ideally, increasing n_threads gives t2 = 0.5*t1 for n_threads=2, t4 = 0.25*t1 for n_threads=4, and t6 = 0.16*t1 (roughly 1/6) for n_threads=6. The 0.16 means -84%, which is a pretty good speedup, I think.
  • too many threads cause a heavy spin + wait/notify burden. When the ROI (speedup rate vs. energy/heat) drops to a certain value, increasing n_threads helps little or even hurts.

All data in the following images were created with llama 7B. I will not show all models because that would be too lengthy, and I can only run 7B/13B. Instead I'll use Q4_0, Q5_0 and Q8_0, because they are enough to make the points.

I ran bench/analyze on my MacBook Pro 2018 with 32 GB 2400 MHz DDR4 memory and a 2.6 GHz 6-core Intel Core i7-8850H.

The data are all plotted as 2-D lines, where the x-axis is M and the y-axis is the per-thread execution time in ms.

4096x4096, Q4_0

The M >=32 rule and bias

The next diagram shows the execution time of profile 0 at stage 0 and stage 1, with a logarithmic axis scale. The stage-0 time is very short and negligible compared to that of stage 1. We can anticipate that:

  • the overall compute time is almost the same as that of stage 1
  • when run with n threads, the per-thread execution time should be 1/n of the single-thread time.
pic-1

The next diagram shows the execution time of profile 1 at stage 1 (BLAS), with a logarithmic axis scale. It is nearly constant when M <= 64; beyond that, Δt/ΔM grows more and more until the time finally becomes linear in M. I guess the time increases so much when M > 64 because of memory pressure: 4096x4096x64 is 2^30 (about 1 billion) float32 values, i.e. 32 Gibit (4 GiB). When memory runs out, the OS has to compress memory or use swap, which would greatly hurt performance.

pic-2

The next picture explains the bias ranges in the current master code. First, let's find the points where the blue line intersects the other lines. The blue line represents the overall execution time of profile 2, whereas the other 4 lines represent the overall execution time of profile 0 at each n_threads. Every profile-0 line intersects the profile-2 line at some point. So given n_threads and M, we can determine the fastest profile (line) at a glance from the intersection point. For Ms not on the x-axis, we can easily estimate the corresponding time.

Now let's focus on the vertical line at M=32. Given n_threads, we can find the corresponding lines for profile 0 and profile 2. Recall the default profile selection policy in master: M >= 32 && N >= 32 && K >= 32. This means that for NxK = 4096x4096, when M < 32 we follow the profile-0 line, otherwise we follow the profile-2 line.

This is ideal when the two lines intersect at M=32; otherwise estimation bias shows up for the Ms between the intersection point and 32. For any profile-0 line, the bias grows from 0 (at the intersection point) to |t0 - t2| (at M=32), where t0 is the profile-0 time and t2 is the profile-2 time. The max bias is so large that it may reach 30% for n_threads=1 and 2, and 60% for n_threads=4 or 6. Of course, as n_threads increases, spinning, memory contention and cache misses cause some performance degradation, so the per-thread average time will not reach the ideal (small) value.
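The intersection point, and hence the size of the bias range, can be computed from bench data. A minimal sketch (crossover_m is a hypothetical name): it assumes both profiles were benched at the same M values and interpolates linearly inside the segment where the sign of the time difference flips.

```c
#include <assert.h>

/* Return the M at which the CPU profile stops being faster than the BLAS
 * profile. Returns ms[0] if BLAS already wins at the smallest M, and -1
 * if BLAS never wins within the benched range. */
static double crossover_m(const int *ms, const double *t_cpu,
                          const double *t_blas, int n) {
    assert(n >= 1);
    if (t_cpu[0] >= t_blas[0])
        return ms[0];
    for (int i = 0; i + 1 < n; i++) {
        double d0 = t_cpu[i] - t_blas[i];
        double d1 = t_cpu[i + 1] - t_blas[i + 1];
        if (d0 < 0 && d1 >= 0) /* sign flips: solve the linear d(m) = 0 */
            return ms[i] + (ms[i + 1] - ms[i]) * (-d0) / (d1 - d0);
    }
    return -1.0;
}
```

With the toy numbers in the test (CPU 2, 4, 8 ms and BLAS a flat 5 ms at M = 16, 32, 64; made-up values, not measurements), the crossover is at M = 40, so a fixed M >= 32 rule would pick the slower BLAS profile for 32 <= M < 40.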

As I said before, M is the token count. Since white spaces and word stems are also counted as tokens, the prompt token count of a typical question or statement is likely to be close to 32.

Anyway, personal computers nowadays tend to have large memory and fast CPUs, so the bias may go unnoticed or be tolerable.

pic-3

Parallel de-quantizing

The next two pictures show the trend of the de-quantization time in the INIT stage as a percentage of the whole execution time. In theory, the de-quantization (INIT) time is determined by N/K only, so it can be treated as a constant. But the BLAS time increases once M > 64.

The important lesson from these plots is that the INIT time is close to or larger than the COMPUTE time over a fairly large M range: up to 128! It is still about 1/3 of the overall time even at M=256. So if we run INIT with multiple threads, we can get far better performance than with a single thread. Ideally, we can speed up by over 50% when M <= 64, and by 30% ~ 40% when M is between 64 and 128.
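A back-of-the-envelope model of the INIT speedup, assuming INIT scales perfectly across n threads, COMPUTE (BLAS) stays single-threaded, and spinning and memory contention are ignored. With T the single-threaded total time and f = T_INIT / T:

```latex
T_n = T_{\text{COMPUTE}} + \frac{T_{\text{INIT}}}{n}
    = T\left(1 - f + \frac{f}{n}\right),
\qquad
1 - \frac{T_n}{T} = f\left(1 - \frac{1}{n}\right) < f .
```

For example, with f = 1/3 (the M = 256 case above) and n = 4, the reduction is (1/3)(3/4) = 25%. Since the reduction can never exceed f, saving over 50% requires INIT to take more than half of the single-threaded time.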

pic-4 pic-5

Finally, here is the multi-threaded plot; for simplicity I only show nth=1 and nth=4. From this picture we can see that the M at the intersection point increases with n_threads. I've seen that there is no intersection point at all when n_threads=8, which means the pure CPU solution always runs faster than the BLAS solution, even when both run with multiple threads.

With fine tuning, given the model, type, M, N, K and n_threads, we will be able to select the correct profile.

pic-7

Other images

I will not explain them. The main reason I list these images is to show their similarities and minor differences.

pic-14

How to evaluate

Build with make or CMake

Make sure one of the BLAS vendors is enabled and compiled into the program.

# Accelerate: make clean; LLAMA_NO_ACCELERATE= make
# OpenBLAS:   make clean; LLAMA_NO_ACCELERATE=1 LLAMA_OPENBLAS=1 make
# BLIS:       make clean; LLAMA_NO_ACCELERATE=1 LLAMA_BLIS=1 make
# CLBlast:    make clean; LLAMA_CLBLAST=1 make

# CUDA is supported but not well tested; it may not run at all.

Evaluate:

NOTE: when GPU offloading is enabled (-ngl > 0), mul_mat tuning is disabled automatically.

# help
./mulmat-tune -h

# tune, use default config, 7B, Q4_0, n_threads=4, ...
./mulmat-tune

# tune and run
./main ... --tune

# tune and save file, exit.
./main ... --tune --tune-file=<FILE>

# load and run:
./main ... --tune-file=<FILE>

./perplexity ... --tune

Have a look at examples/mulmat-tune/README.md for details

Conclusion

Software systems are complicated, and it's hard to optimize when target platforms vary widely. I'm certain that the speedup of q_f32 would not have become reality without the new threading infrastructure, the task config profiles and the mulmat tune tool. I'm happy that after such a long time I am finally able to show you working code. Enjoy!

@ggerganov @SlyEcho @0cc4m @JohannesGaessler @zenixls2 @slaren

EDITED on Jun 18

  • typos
  • hide 5 images
  • tune: sync with latest changes

EDITED ON Jun 26

I haven't updated this PR for a few days, for the following reasons, I think:

  • to support tuning, this PR introduced too many changes.
  • the threading implementation is ugly, full of tricks, and not well tested.
  • it is hard to test on Windows and with CL/CUDA due to my limited personal devices.
  • the design of task profiles is controversial: it is intrusive.
  • it is hard to merge even pieces of the code; it tends to become a troublemaker.
  • finally, in favor of https://github.com/ggerganov/ggml/issues/293

Many thanks to @KerfuffleV2 for helping with testing, and to all of you who spent time on this PR.

I'm sorry, @ggerganov, that this took your time to review. Shall I close this PR?

mqy, May 29 '23 00:05

The CMake build does not work; perhaps I should move mulmat-tune.[c,h] to the root dir.

mqy, May 29 '23 02:05

I was thinking recently that better threading would be nice to have.

Anyways, I didn't yet look at the PR in detail but I can already give you feedback regarding the way you represent your data to make it easier to understand:

  • Add units to the table: you can't tell at a glance what the numbers mean. Then you no longer need to go back and forth between the README and the image.
  • Label the plot axes: again, you cannot tell at a glance what the lines mean. There are benchmarks where lower is better and some where higher is better. Put the meaning of the axes directly in the image.
  • Markdown lets you create tables. That may be a little easier to use than an image.

Regarding the contents of the README: unless I'm misunderstanding something you are at one point talking about doing dequantization on the CPU and then doing the actual matrix multiplication on the GPU. This is not a viable approach. The weights are very large and become even larger when dequantized. Transferring that much data between CPU and GPU is very slow, slower than to just do everything on the CPU. My implementation only works because weights are stored in VRAM and thus don't need to be copied to the GPU.

JohannesGaessler, May 29 '23 05:05

dequantization on the CPU and then doing the actual matrix multiplication on the GPU. This is not a viable approach.

We started out that way, at first cuBLAS was used without the custom kernels. It did work but obviously was much slower than it is now.

SlyEcho, May 29 '23 07:05

The CMake build does not work; perhaps I should move mulmat-tune.[c,h] to the root dir.

I think so. It seems to be another part of ggml, so I would rename them to ggml-tune.{c,h}

SlyEcho, May 29 '23 07:05

The CMake build does not work; perhaps I should move mulmat-tune.[c,h] to the root dir.

I think so. It seems to be another part of ggml, so I would rename them to ggml-tune.{c,h}

mqy, May 29 '23 09:05

@JohannesGaessler Feedback from you and others corrected my misunderstandings. I've improved the README file a bit for now: fixed wrong terms, stopped using images, and pasted some example results. I'll keep updating it.

As for the term backend: similar to the current enum ggml_backend, I had previously defined enum ggml_device for CPU and GPU. Honestly speaking, I have been confused by the terms BLAS and GPU ever since, sorry!

In this PR, the bench result is tightly bound to the specific implementation, so I named several backend vendors for validating the loaded bench file. Now I read the backend as a "mixed implementation on top of hardware and software library spec", and I use it to explicitly control which part of the code runs. I'm aware that your PR Cuda refactor, multi GPU support #1670 is ready to merge, congratulations!

Thanks!

mqy, May 29 '23 10:05

I'll try to fix the CMake build. I'm not familiar with it, so I will refer to the configuration of ggml-opencl.

mqy, May 29 '23 11:05

I'll try to fix the CMake build. I'm not familiar with it, so I will refer to the configuration of ggml-opencl.

Is it optional? Because ggml-opencl is optional.

Otherwise you can just add the files to the ggml library target.

SlyEcho, May 29 '23 12:05

Is it optional? Because ggml-opencl is optional.

As far as I know, ggml-opencl is controlled by a compile flag named LLAMA_OPENCL, while mulmat-tune doesn't have any compile flags at present. I'm not planning to define any compile flag for mulmat-tune, because the field struct ggml_mulmat_tune * mm_tune; was added to both struct llama_context and struct ggml_cgraph. On llama init, if mulmat-tune.txt exists and is successfully loaded and validated, mm_tune is set.

llama passes mm_tune to every ggml_cgraph it creates. In ggml_graph_compute_mul_mat_set_task_profile(), if cgraph->mm_tune is NULL, we fall back to the M >= 32 && N >= 32 && K >= 32 logic.

I expect that in the future, whether to use mulmat tune will be controlled by two command line options: --mulmat-tune-file=FILE to load an existing file, or --mulmat-tune to run the bench right away and use the in-memory result.

I doubt the usefulness of --mulmat-tune because the bench time may look too long. With bench parameters (model=7B, type=Q4_0, m_num=10, n_pass=3), it takes about 75 seconds on my device, while 1 pass takes about 35 seconds. One possible fix: given N/K (both > 0), do not run de-quantization for every M. I will try this later.

Thanks for the tip!

mqy, May 29 '23 13:05

@SlyEcho I just tried 3B, it's amazingly faster than 7B! Thanks!

BTW, the mulmat-tune tool supports 3B now. I also added an environment variable named LLAMA_MULMAT_TUNE_DATA_DIR to ease switching between models/types. Below is how llama loads the bench data file:

const char * env_dir = getenv("LLAMA_MULMAT_TUNE_DATA_DIR");
if env_dir is NULL, then
    try to open file "./mulmat-tune.txt"
else
    try to open file "$env_dir/$MODEL.$FTYPE.txt", where:
    $MODEL is the name of the current in-memory model: 3B, 7B, ...
    $FTYPE is the name of the model's ftype: Q4_0, Q4_1, ...
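The lookup above, written as compilable C. tune_data_path is a hypothetical helper, not the PR's function, and treating an empty LLAMA_MULMAT_TUNE_DATA_DIR like an unset one is an assumption of this sketch. The caller would pass getenv("LLAMA_MULMAT_TUNE_DATA_DIR") as env_dir and then try to open the resulting path.

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Build the bench data path: "./mulmat-tune.txt" when no data dir is set,
 * otherwise "<env_dir>/<model>.<ftype>.txt".
 * Returns 0 on success, -1 if buf is too small. */
static int tune_data_path(char *buf, size_t size, const char *env_dir,
                          const char *model_name, const char *ftype_name) {
    assert(buf != NULL && size > 0);
    int n;
    if (env_dir == NULL || strlen(env_dir) == 0)
        n = snprintf(buf, size, "./mulmat-tune.txt");
    else
        n = snprintf(buf, size, "%s/%s.%s.txt", env_dir, model_name, ftype_name);
    return (n >= 0 && (size_t)n < size) ? 0 : -1;
}
```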

mqy, May 31 '23 01:05

@mqy

Thank you very much for this very detailed and comprehensive study!

I will take a very detailed look into this work in the following days and see what is the best way to integrate it into the project. Just keep in mind there are a couple of other PRs that I will prioritize before this one, so it might take a bit longer to respond. But in the meantime, I would appreciate it if other people looked into this as well and shared their opinions.

I do think that the threading mechanism in ggml can be significantly improved, and from a very quick look through this PR, it looks like it successfully tackles most of the important points.

ggerganov, Jun 02 '23 08:06

@ggerganov I dared to split out the threading code because I really need to test it, and I dislike browsing and editing this code here and there again and again. I'm afraid you may not agree with this, and I had planned to ping you today when it was ready, but I've been stopped by misc Windows compile errors and am still trying to fix them. These problems were caused by splitting the threading code out of ggml.c.

Good news: after splitting and testing, I fixed a race :D. The new threading module is easy to test and is no longer limited to the fixed thread/task runner. Threading looks stable enough now, but it's better to do additional testing and benchmarking as well.

There are still important things to do, I think, for example:

  1. rewrite part of the task-profile-related code, because GGML_BACKEND_CL will be removed.
  2. support multi-threaded mulmat tune, making it possible to optionally bench on startup. Multi-threaded mulmat tune has several potential advantages:
    • speed up the bench to an acceptable time, so we can feel good about running the online bench often.
    • avoid various mismatching problems that are hard to detect and resolve.
    • may somehow close the gap between ideal offline benching and the actual runtime.
  3. the bench tool has a shortcoming: it only supports hybrid CPU + BLAS or CPU + GPU. Perhaps this is OK at present, but it may need to be improved later.

Any kind of feedback is always welcome! Thank you very much!

mqy, Jun 02 '23 09:06

This isn't currently expected to work with cuBLAS, correct? When attempting to run compiled with cuBLAS I get:

invalid backend: 33
GGML_ASSERT: ggml-tune.c:691: false

Please let me know if any other information would be helpful.

KerfuffleV2, Jun 04 '23 12:06

This isn't currently expected to work with cuBLAS, correct? When attempting to run compiled with cuBLAS I get:

Sorry, this was probably caused by a stupid bug; please try again. Thanks!

I do not have CUDA available on macOS, and I failed to run CLBlast due to low GPU memory. I'd really appreciate feedback, especially from platforms with CUDA and CL. Thanks!

mqy, Jun 04 '23 14:06

Glad to help!

I tried again. I don't know if it's important, but I just compile using make, so make clean && make LLAMA_CUBLAS=1

First, for reference perplexity running on master compiled with cuBLAS but no GPU offloading (q4_0 7b LLaMA model):

system_info: n_threads = 6 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 3.70 seconds per pass - ETA 40 minutes
[1]4.4547,[2]4.9402,[3]5.8277,[4]6.4843

Offloading 30 layers: perplexity: 3.43 seconds per pass - ETA 37 minutes

Doesn't help a whole lot since my GPU is pretty old (GeForce 1060 6G VRAM) and calculating perplexity doesn't really care about offloading from what I can see.


Now with your branch, but no layers offloaded: perplexity: 60.94 seconds per pass - ETA 11 hours 5 minutes

It doesn't seem to be using cuBLAS (checking GPU utilization also doesn't show anything going on).

Trying to offload some layers with -ngl 10:

perplexity: calculating perplexity over 655 chunks, batch_size=512
[3]    55146 segmentation fault (core dumped)

It actually ran for about a minute, so I think it did the calculation (on CPU not GPU as above) then crashed at the end of the pass.

I'll try to run it in gdb, compiled with LLAMA_DEBUG enabled, but since it compiles without optimization I don't know how long that will take.

KerfuffleV2, Jun 04 '23 14:06

I got impatient and tried running with batch size 4:

Thread 1 "perplexity" received signal SIGSEGV, Segmentation fault.
0x00005555555632d8 in ggml_vec_mul_f32 (n=4096, z=0x7ffe48010000, x=0x7ffe48000000, y=0x7fffa6820400) at ggml.c:2073
2073    inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
(gdb) bt
#0  0x00005555555632d8 in ggml_vec_mul_f32 (n=4096, z=0x7ffe48010000, x=0x7ffe48000000, y=0x7fffa6820400) at ggml.c:2073
#1  0x0000555555576074 in ggml_compute_forward_mul_f32 (params=0x7ffffffe3c70, src0=0x7ffe68069370, src1=0x7fffa7213340, dst=0x7ffe68069490) at ggml.c:8156
#2  0x000055555557632e in ggml_compute_forward_mul (params=0x7ffffffe3c70, src0=0x7ffe68069370, src1=0x7fffa7213340, dst=0x7ffe68069490) at ggml.c:8194
#3  0x000055555558721f in ggml_compute_forward (params=0x7ffffffe3c70, tensor=0x7ffe68069490) at ggml.c:12885
#4  0x00005555555dab19 in ggml_threading_compute_tensor (threading_ctx=0x55556a9dd0a0, node=0x7ffe68069490, work=0x7ffe6806f5d0) at ggml-threading.c:427
#5  0x000055555558ba66 in ggml_graph_compute (ctx=0x5555556b2b48 <g_state+200>, cgraph=0x7ffffffe40d0) at ggml.c:14351
#6  0x00005555555954db in llama_eval_internal (lctx=..., tokens=0x55556aec13c0, n_tokens=4, n_past=0, n_threads=6) at llama.cpp:1431
#7  0x000055555559beeb in llama_eval (ctx=0x55556ab0a210, tokens=0x55556aec13c0, n_tokens=4, n_past=0, n_threads=6) at llama.cpp:2989
#8  0x000055555555b245 in perplexity (ctx=0x55556ab0a210, params=...) at examples/perplexity/perplexity.cpp:62
#9  0x000055555555ba0e in main (argc=12, argv=0x7fffffffd958) at examples/perplexity/perplexity.cpp:164
(gdb) print *z
$1 = 0.0190265197
(gdb) print *y
$2 = 0
(gdb) up
#1  0x0000555555576074 in ggml_compute_forward_mul_f32 (params=0x7ffffffe3c70, src0=0x7ffe68069370, src1=0x7fffa7213340, dst=0x7ffe68069490) at ggml.c:8156
8156                ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
(gdb) up
#2  0x000055555557632e in ggml_compute_forward_mul (params=0x7ffffffe3c70, src0=0x7ffe68069370, src1=0x7fffa7213340, dst=0x7ffe68069490) at ggml.c:8194
8194                    ggml_compute_forward_mul_f32(params, src0, src1, dst);
(gdb) up
#3  0x000055555558721f in ggml_compute_forward (params=0x7ffffffe3c70, tensor=0x7ffe68069490) at ggml.c:12885
12885                   ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
(gdb) up
#4  0x00005555555dab19 in ggml_threading_compute_tensor (threading_ctx=0x55556a9dd0a0, node=0x7ffe68069490, work=0x7ffe6806f5d0) at ggml-threading.c:427
427                 state_shared->task_stage_runner(&params, node);
(gdb) up
#5  0x000055555558ba66 in ggml_graph_compute (ctx=0x5555556b2b48 <g_state+200>, cgraph=0x7ffffffe40d0) at ggml.c:14351
14351           ggml_threading_compute_tensor(thrd_ctx, node, cgraph->work);
(gdb) up
#6  0x00005555555954db in llama_eval_internal (lctx=..., tokens=0x55556aec13c0, n_tokens=4, n_past=0, n_threads=6) at llama.cpp:1431
1431        ggml_graph_compute       (ctx0, &gf);

Valgrind:

system_info: n_threads = 6 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
perplexity: calculating perplexity over 655 chunks, batch_size=4
==57615== Conditional jump or move depends on uninitialised value(s)
==57615==    at 0x13EC9A: ggml_graph_compute (ggml.c:14031)
==57615==    by 0x1494DA: llama_eval_internal(llama_context&, int const*, int, int, int) (llama.cpp:1431)
==57615==    by 0x14FEEA: llama_eval (llama.cpp:2989)
==57615==    by 0x10F244: perplexity(llama_context*, gpt_params const&) (perplexity.cpp:62)
==57615==    by 0x10FA0D: main (perplexity.cpp:164)
==57615== 
==57615== Conditional jump or move depends on uninitialised value(s)
==57615==    at 0x13E74D: ggml_graph_compute_mul_mat_set_tensor_task_profile (ggml.c:13953)
==57615==    by 0x13EDC6: ggml_graph_compute (ggml.c:14057)
==57615==    by 0x1494DA: llama_eval_internal(llama_context&, int const*, int, int, int) (llama.cpp:1431)
==57615==    by 0x14FEEA: llama_eval (llama.cpp:2989)
==57615==    by 0x10F244: perplexity(llama_context*, gpt_params const&) (perplexity.cpp:62)
==57615==    by 0x10FA0D: main (perplexity.cpp:164)
==57615== 
==57615== Use of uninitialised value of size 8
==57615==    at 0x18D4F0: ggml_mulmat_tune_get_shape (ggml-tune.c:335)
==57615==    by 0x13E83B: ggml_graph_compute_mul_mat_set_tensor_task_profile (ggml.c:13961)
==57615==    by 0x13EDC6: ggml_graph_compute (ggml.c:14057)
==57615==    by 0x1494DA: llama_eval_internal(llama_context&, int const*, int, int, int) (llama.cpp:1431)
==57615==    by 0x14FEEA: llama_eval (llama.cpp:2989)
==57615==    by 0x10F244: perplexity(llama_context*, gpt_params const&) (perplexity.cpp:62)
==57615==    by 0x10FA0D: main (perplexity.cpp:164)
==57615== 
==57615== Invalid read of size 4
==57615==    at 0x18D4F0: ggml_mulmat_tune_get_shape (ggml-tune.c:335)
==57615==    by 0x13E83B: ggml_graph_compute_mul_mat_set_tensor_task_profile (ggml.c:13961)
==57615==    by 0x13EDC6: ggml_graph_compute (ggml.c:14057)
==57615==    by 0x1494DA: llama_eval_internal(llama_context&, int const*, int, int, int) (llama.cpp:1431)
==57615==    by 0x14FEEA: llama_eval (llama.cpp:2989)
==57615==    by 0x10F244: perplexity(llama_context*, gpt_params const&) (perplexity.cpp:62)
==57615==    by 0x10FA0D: main (perplexity.cpp:164)
==57615==  Address 0xf0f0000034170038 is not stack'd, malloc'd or (recently) free'd
==57615== 
==57615== 
==57615== Process terminating with default action of signal 11 (SIGSEGV)
==57615==  General Protection Fault
==57615==    at 0x18D4F0: ggml_mulmat_tune_get_shape (ggml-tune.c:335)
==57615==    by 0x13E83B: ggml_graph_compute_mul_mat_set_tensor_task_profile (ggml.c:13961)
==57615==    by 0x13EDC6: ggml_graph_compute (ggml.c:14057)
==57615==    by 0x1494DA: llama_eval_internal(llama_context&, int const*, int, int, int) (llama.cpp:1431)
==57615==    by 0x14FEEA: llama_eval (llama.cpp:2989)
==57615==    by 0x10F244: perplexity(llama_context*, gpt_params const&) (perplexity.cpp:62)
==57615==    by 0x10FA0D: main (perplexity.cpp:164)

This is with your latest commits: https://github.com/ggerganov/llama.cpp/pull/1632/commits/2ed107183b1e65f83456931c76a98124cc7ee356

KerfuffleV2, Jun 04 '23 14:06

I said "latest commit", but that wasn't quite true, since you pushed another one while I was composing the message. However, https://github.com/ggerganov/llama.cpp/pull/1632/commits/d5d900d75c09f894e3ba0960950ef8b9df7f4aa4 doesn't seem to have made a difference in the behavior: it still doesn't seem to actually use cuBLAS, and it crashes when -ngl is specified. (I didn't compile with debug to test that, though, because it is sooooo sloowwwwww.)

I can retest when you want.

KerfuffleV2, Jun 04 '23 15:06

@KerfuffleV2 Looks like the crash was caused by the lovely uninitialized memory. Please try again. BTW, could you please paste the commands you ran, if possible? With those, I can reproduce what you did. Thanks!
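For readers following along: the Valgrind report above ("Conditional jump or move depends on uninitialised value(s)", then an invalid read in ggml_mulmat_tune_get_shape at a wild address) is the classic signature of a pointer field that was never initialized. Below is a minimal sketch of the usual remedy, using hypothetical struct and function names (tune_shape, tune_data, compute_ctx, find_shape) rather than the PR's real code: give every field a defined value at creation time, and treat optional tune data as absent when its pointer is NULL.

```c
#include <stddef.h>

/* Hypothetical types for illustration only; not the PR's actual structs. */
struct tune_shape {
    int N;
    int K;
};

struct tune_data {
    int n_shapes;                /* number of valid entries in shapes[] */
    struct tune_shape shapes[4];
};

struct compute_ctx {
    int n_threads;
    struct tune_data *tune;      /* optional: NULL when no bench data was loaded */
};

/* Initialize every field; members not named in the compound literal are
 * zeroed, so `tune` is guaranteed to start out as NULL. */
static void compute_ctx_init(struct compute_ctx *ctx, int n_threads) {
    *ctx = (struct compute_ctx){ .n_threads = n_threads };
}

/* Find the tuned entry for (N, K). Returns NULL when there is no tune data or
 * no matching shape. If `tune` were left uninitialized it would hold garbage,
 * and the dereference below would be the kind of invalid read Valgrind reports. */
static const struct tune_shape *find_shape(const struct compute_ctx *ctx, int N, int K) {
    if (ctx == NULL || ctx->tune == NULL) {
        return NULL;
    }
    for (int i = 0; i < ctx->tune->n_shapes && i < 4; i++) {
        const struct tune_shape *s = &ctx->tune->shapes[i];
        if (s->N == N && s->K == K) {
            return s;
        }
    }
    return NULL;
}
```

The compound-literal assignment is used instead of memset so that the NULL pointer value is guaranteed by the C standard rather than by the platform's representation of a null pointer.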

mqy, Jun 04 '23 15:06

@KerfuffleV2 I assume you have run the mulmat-tune bench; would you please share the first 15 lines of the bench results? I want to confirm that the bench result format on your platform is as expected.

mqy, Jun 04 '23 15:06

I assume you have run the mulmat-tune bench

Sorry, no... I didn't know it was necessary. Is crashing/not working properly the expected behavior in that case? I apologize if I wasted your time.

I had some trouble trying to compile it:

cc  -I.              -O3 -std=c11   -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -march=native -mtune=native -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/opt/cuda/targets/x86_64-linux/include examples/mulmat-tune/mulmat-tune.c ggml.o ggml-cuda.o ggml-tune.o ggml-threading.o -o mulmat-tune -lm -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L/opt/cuda/targets/x86_64-linux/lib
/usr/bin/ld: ggml-cuda.o:(.data.rel.local.DW.ref.__gxx_personality_v0[DW.ref.__gxx_personality_v0]+0x0): undefined reference to `__gxx_personality_v0'
collect2: error: ld returned 1 exit status
make: *** [Makefile:275: mulmat-tune] Error 1

It needs to be linked with -lstdc++ on Linux at least:

$(CC) $(CFLAGS) $^ -o mulmat-tune $(LDFLAGS) -lstdc++

However, it crashes when I try to run it:

./mulmat-tune bench 
[bench] model: 7B, type: Q4_0, backend: CUDA

4096 4096 1 ..GGML_ASSERT: ggml-cuda.cu:873: ggml_cuda_can_mul_mat(src0, src1, dst)
[1]    69219 IOT instruction (core dumped)  ./mulmat-tune bench

I tried compiling it with OpenBLAS and it does run, so I don't think I screwed stuff up this time at least.

BTW, could you please paste the commands you ran, if possible?

Basically just following the comments in the perplexity example:

./perplexity -f /path/to/wiki.test.raw -m /path/to/llama-7b.ggmlv3.q4_0.bin -t 6 -ngl 10

It does require downloading the wikitext test data; there's a link to it in the source for perplexity.

KerfuffleV2, Jun 04 '23 16:06

Sorry, no... I didn't know it was necessary. Is crashing/not working properly the expected behavior in that case? I apologize if I wasted your time.

No worries, I really do need feedback and help. Looks like the assertion error was caused by clearing tensor.backend in a previous commit; I reverted it.

I have to say that the CUDA/CL-related parts are far from bug-free at present, because the scheduling rule was changed in this PR. Unfortunately, I haven't dived deeply into the details and am unable to test them; I'll spend some time on this part.

mqy, Jun 04 '23 16:06

Looks like the assertion error was caused by clearing tensor.backend in a previous commit; I reverted it.

Unfortunately it doesn't seem to have fixed the issue:

./mulmat-tune bench
[bench] model: 7B, type: Q4_0, backend: CUDA

4096 4096 1 ..GGML_ASSERT: ggml-cuda.cu:873: ggml_cuda_can_mul_mat(src0, src1, dst)
[1]    85454 IOT instruction (core dumped)  ./mulmat-tune bench

That is compiled with https://github.com/ggerganov/llama.cpp/pull/1632/commits/9c50185a6049f340c328473af1fa296a041e1c11

I have to say that the CUDA/CL-related parts are far from bug-free at present, because the scheduling rule was changed in this PR. Unfortunately, I haven't dived deeply into the details and am unable to test them; I'll spend some time on this part.

No problem, just thought I'd check out this pull. I'm not complaining if it doesn't work currently or anything like that.

Unfortunately, helping you fix this kind of stuff is beyond me, but if you make changes you want someone with an Nvidia card to test later on, just @ me and I'll probably be able to help.

KerfuffleV2, Jun 04 '23 17:06

Quick update: I checked out the latest version and gave it another try.

mulmat-tune bench now runs, however it doesn't seem to use cuBLAS.

[bench] model: 7B, type: Q4_0

I double checked and there's no GPU utilization while the test runs. I'm guessing you probably just have cuBLAS disabled for that stuff, but I thought I'd provide this information just in case it's helpful.

KerfuffleV2, Jun 06 '23 16:06

I double checked and there's no GPU utilization while the test runs. I'm guessing you probably just have cuBLAS disabled for that stuff, but I thought I'd provide this information just in case it's helpful.

@KerfuffleV2 Thanks for the feedback!

cuBLAS is supported. I can bench CLBlast, but have not tested CUDA yet. Please check your generated bench data file: if CUDA is built into mulmat-tune, it SHOULD run and output something like this:

16 0 0 0  16 1 0 1   0 0 0 0
16 1 0 2  17 0 1 0   0 0 0 0
 0 0 0 0  34 0 1 0   0 0 0 0

These are example profiles for CL, where 16 stands for CPU, 17 for BLAS, 33 for CUDA and 34 for CL. If you do not see 33 in the file, that MUST be a bug.
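The codes look like bit flags: 16 = 0x10, 17 = 0x10 | 1, 33 = 0x20 | 1 and 34 = 0x20 | 2, which would also fit the `comp_backend & GGML_TASK_BACKEND_CPU` assertion that shows up later in this thread. The sketch below decodes the sample rows under two assumptions that are not confirmed by the PR: that 0x10 marks CPU-side backends and 0x20 GPU-side ones (the BACKEND_* names are hypothetical), and that each row holds three groups of four integers, one group per task stage, whose first integer is the backend code. The remaining three integers in each group are left uninterpreted.

```c
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical names; the values are the codes listed in the comment above. */
enum {
    BACKEND_NONE = 0,
    BACKEND_CPU  = 0x10,              /* 16 */
    BACKEND_BLAS = BACKEND_CPU | 0x1, /* 17: BLAS still runs on the CPU */
    BACKEND_GPU  = 0x20,              /* assumed GPU flag bit */
    BACKEND_CUDA = BACKEND_GPU | 0x1, /* 33 */
    BACKEND_CL   = BACKEND_GPU | 0x2  /* 34 */
};

/* Human-readable name for a backend code. */
static const char *backend_name(int code) {
    switch (code) {
    case BACKEND_NONE: return "none";
    case BACKEND_CPU:  return "CPU";
    case BACKEND_BLAS: return "BLAS";
    case BACKEND_CUDA: return "CUDA";
    case BACKEND_CL:   return "CL";
    default:           return "unknown";
    }
}

/* True when the code carries the CPU flag bit (CPU or BLAS). */
static bool backend_on_cpu(int code) {
    return (code & BACKEND_CPU) != 0;
}

/* Parse one profile row, assumed to be three groups of four integers (one group
 * per task stage). Keeps only the first integer of each group, the backend code.
 * Returns false unless all 12 integers were read. */
static bool parse_profile_row(const char *line, int backends[3]) {
    int v[12];
    int n = sscanf(line, "%d %d %d %d %d %d %d %d %d %d %d %d",
                   &v[0], &v[1], &v[2], &v[3], &v[4], &v[5],
                   &v[6], &v[7], &v[8], &v[9], &v[10], &v[11]);
    if (n != 12) {
        return false;
    }
    for (int s = 0; s < 3; s++) {
        backends[s] = v[s * 4];
    }
    return true;
}
```

Under these assumptions the second sample row, `16 1 0 2  17 0 1 0   0 0 0 0`, decodes as CPU, BLAS, none: a CPU stage followed by a BLAS stage, which would fit the PR title's idea of splitting the q_f32 BLAS work so that de-quantization can run on the CPU first.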

I'm still working on mulmat-tune, try fixing various stupid bugs or corner cases.

mqy, Jun 06 '23 19:06

I'm sorry, it actually did work. I guess I just stopped it too early before (previously it explicitly said when it was using a CUDA backend). It seems to run the tests on the CPU first. Here are the results I got: https://gist.github.com/KerfuffleV2/03142913e3eed6144f00ae6010c893e5

Is there anything you'd currently like me to test with CUDA since that part seems okay now?


I tried the perplexity example again. It seems to run successfully (I had to save the test results to mulmat-tune.txt). The test was run on a 7B Q4_0 LLaMA model (the same one mulmat-tune benched):

mulmat-tune: loaded file ./mulmat-tune.txt

system_info: n_threads = 6 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 7.05 seconds per pass - ETA 1 hours 16 minutes
[1]4.4543,[2]4.9401,[3]5.8279,[4]6.4844,[5]6.5856,^C

vs master:

system_info: n_threads = 6 / 12 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 3.84 seconds per pass - ETA 41 minutes
[1]4.4543,[2]4.9401,[3]5.8278,[4]6.4844,[5]6.5855,^C
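A quick check of the two runs: the printed ETAs are consistent with seconds-per-pass × 655 chunks (7.05 × 655 ≈ 4618 s, about 1 h 16 min; 3.84 × 655 ≈ 2515 s, about 41 min), so in this configuration the branch is roughly 1.84× slower per pass than master, while the first five perplexity values agree to within 0.0001. The helpers below only reproduce that arithmetic; the ETA formula is an assumption that happens to match both logs, not a quote of perplexity.cpp.

```c
/* Whole minutes of ETA, assuming ETA = seconds_per_pass * n_chunks, truncated
 * to whole seconds (this matches both logs above). */
static int eta_minutes(double seconds_per_pass, int n_chunks) {
    int total_seconds = (int)(seconds_per_pass * n_chunks);
    return total_seconds / 60;
}

/* How many times slower run `a` is than run `b`, per pass. */
static double slowdown(double a_seconds, double b_seconds) {
    return a_seconds / b_seconds;
}
```

For example, eta_minutes(7.05, 655) gives 76 (1 h 16 min), eta_minutes(3.84, 655) gives 41, and slowdown(7.05, 3.84) is about 1.84.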

KerfuffleV2, Jun 06 '23 19:06

Is there anything you'd currently like me to test with CUDA since that part seems okay now?

The numbers in the data file look pretty good; CUDA should be nearly competitive with the CPU at all batch sizes, but the perplexity run shows the tuned build is far too slow. I'll take a look at the threading part later. Would you please try fewer n_threads: 1, 2, 4?

For how to run perplexity in less time, please have a look at the "Compare With Master" section of the README.

mqy, Jun 06 '23 19:06

Force-pushed a rebase onto master. Tested ACCELERATE and CLBLAST (tune, ngl). I had tested OpenBLAS and BLIS before, so they should work.

CLBLAST is slow on my device, so I just ran mock tests to cover the expected use cases. With 3B, I was surprised that CL no longer crashes on my device due to OOM; thank you @0cc4m and others.

Basically, this PR does not aim at optimizing hybrid CPU+GPU solutions, but it inevitably has to deal with certain architecture-level designs and trade-offs. Honestly speaking, there are several important decisions to make, @ggerganov.

To celebrate this milestone, let me show you actual data collected before the last rebase.

Based on 7B, Q4_0, ACCELERATE, with 4 threads.

perplexity

M     Master (2d43387d)   Incoming (no tune)   Incoming (tune)
8     96.99               96.07                97.37
16    94.44               93.63                95.70
24    99.95               94.15                95.53
32    247.04              158.76               93.58
48    191.63              123.73               98.12
64    137.09              85.69                88.05
96    137.25              87.66                74.53
128   83.74               58.77                58.61
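One reading of the table: master's number jumps at M = 32 (99.95 to 247.04), which coincides with the fixed rule stated in the PR description (run BLAS in a single OS thread when src0 and src1 are contiguous and M >= 32 && N >= 32 && K >= 32), while the tuned build stays flat there (95.53 to 93.58). The sketch below contrasts that fixed threshold with choosing whichever backend was benchmarked cheapest for the current shape. The function names and the per-shape cost array are hypothetical illustrations, not the PR's actual tuning code.

```c
#include <stdbool.h>
#include <stddef.h>

/* Fixed rule from the PR description: use BLAS when both inputs are
 * contiguous and all of M, N, K are at least 32. */
static bool master_rule_uses_blas(int M, int N, int K, bool contiguous) {
    return contiguous && M >= 32 && N >= 32 && K >= 32;
}

/* Hypothetical tuned choice: given benchmarked costs for each candidate
 * backend at the current shape, return the index of the cheapest one,
 * or -1 when there are no candidates. */
static int pick_fastest(const double *costs, size_t n) {
    int best = -1;
    for (size_t i = 0; i < n; i++) {
        if (best < 0 || costs[i] < costs[best]) {
            best = (int)i;
        }
    }
    return best;
}
```

With N = K = 4096, master_rule_uses_blas(24, 4096, 4096, true) is false and master_rule_uses_blas(32, 4096, 4096, true) is true, so the switch happens at exactly one M regardless of hardware; a measured choice can instead keep the CPU path wherever it benchmarks faster.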

Detailed steps and results are listed in examples/mulmat-tune/README.md.

Thanks to @KerfuffleV2 for evaluating with cuBLAS.

For now, I can think of the following TODOs for myself:

  • ensure CUDA/CL is not broken
  • ensure there are no obvious performance drops; take care with the related GPU features
  • review threading

Evaluations, feedback and any suggestions are welcome, thank you!

mqy, Jun 06 '23 22:06

Would you please try fewer n_threads: 1, 2, 4?

Unfortunately, with the latest changes we're back to running into an assertion failure:

ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce GTX 1060 6GB
[bench] model: 7B, type: Q4_0

4096 4096 1 ..GGML_ASSERT: ggml.c:10024: false
[1]    113934 IOT instruction (core dumped)  ./mulmat-tune bench

I actually tried running perplexity first but when it failed I thought I might have to regenerate the mulmat-tune results.

perplexity: calculating perplexity over 655 chunks, batch_size=512
GGML_ASSERT: ggml.c:14241: false

KerfuffleV2, Jun 07 '23 07:06

Unfortunately, with the latest changes we're back to running into an assertion failure:

@KerfuffleV2 Sorry, my fault; this assertion failure should be fixed in the latest commit.

Edit: I pushed another fix for CUDA

mqy, Jun 07 '23 08:06

GGML_ASSERT: ggml.c:10034: comp_backend & GGML_TASK_BACKEND_CPU

Looks like ggml_compute_forward_mul_mat_q_f32 may need a similar change?

KerfuffleV2, Jun 07 '23 08:06