ggml : get rid of BLAS and all its variants

Open ggerganov opened this issue 2 years ago • 39 comments

This is a big one

The only reason we use BLAS is that we don't have an efficient implementation of matrix x matrix multiplication. Naively doing parallel dot products is not optimal. We need to implement some of the fundamental GEMM optimizations, such as block tiling, and we need to do so in a compact way that reuses the existing dot product code and supports all quantization types.
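As a rough illustration of the idea (not code from ggml), block tiling on top of a dot product routine might look like the sketch below. The names `vec_dot_f32`, `gemm_tiled`, `TILE_MN` and `TILE_K` are hypothetical, the tile sizes are untuned guesses, and `B` is assumed to be stored transposed so that every output element is a dot product of two contiguous rows.

```c
#include <assert.h>
#include <stddef.h>

/* Stand-in for an existing per-type dot product kernel (hypothetical name). */
static float vec_dot_f32(size_t n, const float *x, const float *y) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) sum += x[i] * y[i];
    return sum;
}

/* Assumed tile sizes; real values would be tuned per cache level. */
enum { TILE_MN = 16, TILE_K = 64 };

static size_t min_sz(size_t a, size_t b) { return a < b ? a : b; }

/* C[m x n] = A[m x k] * B, with B given transposed as Bt[n x k].
 * The loops walk over tiles so that a TILE_MN x TILE_K slab of A and of Bt
 * is reused for a whole TILE_MN x TILE_MN block of C before moving on. */
static void gemm_tiled(size_t m, size_t n, size_t k,
                       const float *A, const float *Bt, float *C) {
    assert(A != NULL && Bt != NULL && C != NULL);
    for (size_t i = 0; i < m * n; ++i) C[i] = 0.0f;

    for (size_t i0 = 0; i0 < m; i0 += TILE_MN) {
        size_t i1 = min_sz(i0 + TILE_MN, m);
        for (size_t j0 = 0; j0 < n; j0 += TILE_MN) {
            size_t j1 = min_sz(j0 + TILE_MN, n);
            for (size_t k0 = 0; k0 < k; k0 += TILE_K) {
                size_t kc = min_sz(k0 + TILE_K, k) - k0;
                /* Partial dot products over the current k-slice. */
                for (size_t i = i0; i < i1; ++i)
                    for (size_t j = j0; j < j1; ++j)
                        C[i * n + j] += vec_dot_f32(kc, A + i * k + k0,
                                                        Bt + j * k + k0);
            }
        }
    }
}
```

For quantized types, `TILE_K` would presumably have to be a multiple of the quantization block size so that each k-slice starts on a block boundary; the per-type dot product could then be swapped in without changing the loop structure.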

More comments on this:

  • https://github.com/ggerganov/llama.cpp/issues/1867#issuecomment-1595702365
  • https://github.com/ggerganov/llama.cpp/pull/1935#issuecomment-1597140738
  • https://github.com/ggerganov/ggml/issues/293#issuecomment-1607387005

ggerganov avatar Jun 25 '23 09:06 ggerganov

supports all quantization types

The best place for dequantization would be after you've loaded data from device -> threadgroup memory, and are loading into registers. I hypothesize it will be faster to have each of the four simds in a threadgroup unpack on their own, even if that duplicates the work two times.

For example, you might fork my MFA repository, then modify the line here to unpack after loading.

https://github.com/philipturner/metal-flash-attention/blob/fbfd0a028402c0ae6fa293c2f91de318b95e359b/Sources/GEMM.metal#L129

philipturner avatar Jun 25 '23 21:06 philipturner

While I do understand the desire for the project to be independent of other libraries, I personally do not think removing the excellent cuBLAS implementation entirely is a good idea. One of the main benefits of this project is that you can run large models at good speeds on lower-end hardware with less VRAM. For example, thanks to cuBLAS's fast prompt processing of 15 ms/t, I can enjoy a 13B parameter model with full context at around 1.6 tokens/s, which is much faster than running GPTQ with CPU offloading (0.4 tokens/s) on my RTX 2060 laptop.

As you've said yourself, a native GEMM implementation would likely be slower than what cuBLAS is offering. Even if the performance difference is not drastic on your hardware, it can make all the difference on hardware like mine. Even a slight performance difference (for example, 15 ms/t to 23 ms/t) would lead to a worse experience for people who run LLMs too big for their systems. This applies not just to low- to mid-spec hardware, but also to high-end hardware trying to run 65B models.

Please keep this perspective in mind as you continue to develop the project. I don't think it would be a good outcome for me and many others to have to downgrade to older versions with cuBLAS and not enjoy new enhancements just because older versions run faster due to cuBLAS support.

Dampfinchen avatar Jun 26 '23 09:06 Dampfinchen

Why would custom code be slower than cuBLAS? There is Nvidia CUTLASS. The only things it can be are equal or faster.

philipturner avatar Jun 26 '23 12:06 philipturner

Why would custom code be slower than cuBLAS? There is Nvidia CUTLASS. The only things it can be are equal or faster.

They aren't going to use CUTLASS though, because it'd be another third party lib.

cuBLAS is highly optimized for the hardware, and Georgi said himself in this comment that he is aware custom code is not going to perform as well:

https://github.com/ggerganov/llama.cpp/issues/1867#issuecomment-1595702365

And I fully understand that it will be close to impossible to achieve the maximum performance available from dedicated libraries (such as cuBLAS, for example).

Dampfinchen avatar Jun 26 '23 12:06 Dampfinchen

300 lines of code, (soon to) outperform all of Apple's proprietary Metal Performance Shaders.

https://github.com/philipturner/metal-flash-attention/blob/main/Sources/GEMM.metal

If we take some code from CUTLASS, maybe just optimize for the matrix shapes that exist in LLaMA.

philipturner avatar Jun 26 '23 12:06 philipturner

A custom GEMM implementation will be faster with quantized models - that's one of the goals. There may be a small performance regression with f16 and f32 models, though.

slaren avatar Jun 26 '23 12:06 slaren

Could we utilize the int4 hardware support in Ampere tensor cores?

philipturner avatar Jun 26 '23 12:06 philipturner

From what I have seen of the way int4 works with tensor cores, I don't think so. We cannot do a matrix multiplication directly in int4, we need to dequantize to f16 or f32 first. But we can still use the tensor cores after dequantizing to float. I may be wrong, though.

slaren avatar Jun 26 '23 12:06 slaren

A custom GEMM implementation will be faster with quantized models - that's one of the goals. There may be a small performance regression with f16 and f32 models, though.

Glad to hear that. I hope that will be the case.

Could we utilize the int4 hardware support in Ampere tensor cores?

Just a heads up: Tensor cores in Turing, Ampere and Ada Lovelace support INT4, INT8 and FP16 instructions. Ampere and Turing support INT1 as well.

While only Ampere and Ada support FP8 and FP32 in addition to that.

So ideally, the code would use FP16 or the integer instructions I mentioned to cover a wide range of hardware with tensor core support.

Dampfinchen avatar Jun 26 '23 12:06 Dampfinchen

Didn't they remove INT4 on Ada and Hopper?

philipturner avatar Jun 26 '23 13:06 philipturner

Didn't they remove INT4 on Ada and Hopper?

Only on Hopper. INT4 is still present in consumer Ada Lovelace.

https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf Page 24:

Compared to Ampere, Ada delivers more than double the FP16, BF16, TF32, INT8, and INT4 Tensor TFLOPS, and also includes the Hopper FP8 Transformer Engine, delivering over 1.3 PetaFLOPS of tensor processing in the RTX 4090.

It looks like Ada removed support for INT1 though.

Dampfinchen avatar Jun 26 '23 13:06 Dampfinchen

I'd tend to agree with @Dampfinchen : being interested in the Intel platform I don't believe we'll be able to outperform the MKL and oneAPI engineers.

goerch avatar Jun 26 '23 19:06 goerch

I tried implementing a (dequantization +) matrix matrix multiplication CUDA kernel but I'm struggling to get past 50% of cuBLAS performance for prompt processing. In particular, I found against my expectation that fusing dequantization + matrix matrix multiplication does not have a large impact on performance, possibly because you're limited by compute rather than memory bandwidth for large matrices (I am currently not using tensor cores).

What level of performance/sophistication is the goal for something that could possibly be merged? As of right now my implementation could already be useful for token generation since the VRAM usage will be lower compared to cuBLAS. But for prompt processing it is clearly worse. In general I think leaving cuBLAS as a compilation option would be desirable because given the small impact of tensor fusion I don't think I can realistically beat it for prompt processing performance.

JohannesGaessler avatar Jun 26 '23 21:06 JohannesGaessler

50% of cuBLAS performance for prompt processing.

If you get 50% of cuBLAS performance, then what is cuBLAS performance in ALU utilization? Perhaps both underutilize the processor, leaving much room to improve.
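To make the question concrete, utilization can be estimated from the GEMM shape, the measured time, and the device's peak throughput. The helper below is only a sketch; `gemm_utilization` is a hypothetical name, and the numbers in the usage note are invented for illustration, not measurements from this thread.

```c
#include <assert.h>

/* Fraction of peak arithmetic throughput reached by an m x n x k GEMM.
 * A GEMM performs 2*m*n*k floating point operations
 * (one multiply and one add per term of each dot product). */
static double gemm_utilization(double m, double n, double k,
                               double seconds, double peak_flops) {
    assert(seconds > 0.0 && peak_flops > 0.0);
    double achieved_flops = 2.0 * m * n * k / seconds;
    return achieved_flops / peak_flops;
}
```

For example, a made-up 4096 x 4096 x 512 multiplication that takes 2 ms on a device with an assumed peak of 20 TFLOPS gives `gemm_utilization(4096, 4096, 512, 0.002, 20e12)`, which is about 0.43. A kernel at half that speed would sit near 21%.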

philipturner avatar Jun 26 '23 22:06 philipturner

If y'all progressively get rid of BLAS libraries, cuBLAS is probably lowest on the totem pole? AFAIK users still need the huge CUDA toolkit to run CUDA inference anyway, so it's hardly even getting rid of a dependency.

CLBlast and CPU BLAS, on the other hand, can be tricky, but their implementations are open source. I just tried to get the OpenBLAS build working on an Ampere instance for a few hours... and ultimately failed.

AlphaAtlas avatar Jun 26 '23 23:06 AlphaAtlas

Why not just write the entire thing in one Mojo file.

philipturner avatar Jun 26 '23 23:06 philipturner

What level of performance/sophistication is the goal for something that could possibly be merged?

Personally, I wouldn't want this enabled by default until the performance with quantized models is at least comparable to cuBLAS. It's ok if performance with f16/f32 models is worse. But I think it could still be useful to have it merged as an option, just disabled by default. It would be a starting point, and we could keep improving it over time until we reach the performance goal.

slaren avatar Jun 27 '23 08:06 slaren

Why not just write the entire thing in one Mojo file.

Because Mojo is vaporware and doesn't actually exist yet? You can't yet download or run Mojo, it is useless for now.

LoganDark avatar Jun 28 '23 03:06 LoganDark

There's a reason they're close-sourcing it for now, the same reason I close-sourced MFA for ~2 months. It's too buggy at the moment and will develop faster the way it is now.

Now I open-source when it is ready, and the decision pays off.

philipturner avatar Jun 28 '23 03:06 philipturner

Because Mojo is vaporware and doesn't actually exist yet?

When I learned what Modular was doing, I quit AI and shifted careers. No cap. There's nothing left for me to do because Modular is going to solve it.

philipturner avatar Jun 28 '23 03:06 philipturner

There's a reason they're close-sourcing it for now, the same reason I close-sourced MFA for ~2 months. It's too buggy at the moment and will develop faster the way it is now.

Now I open-source when it is ready, and the decision pays off.

While I absolutely don't claim to know that it will never release, there is plenty of reason NOT to bet on it just yet considering it isn't public.

It's not just closed source, you can't even download binaries yet.

LoganDark avatar Jun 28 '23 21:06 LoganDark

@ggerganov Can you clarify the current/planned threading model for CPU computation? This seems like it should be central to the discussion... BLAS is multi-threaded and works extremely well when the calling program is single-threaded. GGML appears to use threads to support concurrent execution of unrelated tasks. I wonder if the majority of workflows would be better off with a single-threaded top-level scheduler, with all cores assigned to work on individual large-ish tasks, a la the BLAS computation model.

evanmiller avatar Jul 07 '23 02:07 evanmiller

I must say that this is simply not possible. I recommend reading the paper titled Anatomy of High-Performance Matrix Multiplication, written in 2008. Achieving high performance requires significant sacrifices. Have you looked at the code for Goto BLAS or OpenBLAS? They are all written in assembly! Yes, assembly language! You need to understand intimately how the hardware works and gauge exactly how far you can push it in order to achieve maximum performance. @ggerganov

bobqianic avatar Aug 10 '23 15:08 bobqianic

You would be correct if we were just doing FP32 + FP32 -> FP32 matrix multiplication. But we are not. The matrix is quantized to some custom data format that consists mostly of low-precision integers + some floating point scales. This data format can not be directly used by any regular BLAS library. So currently the quantized data needs to be converted to FP32 first which costs you both compute time and extra memory.

If you were to instead convert the hidden state from FP32 to q8_1 you would also be able to drastically reduce the amount of floating point instructions and replace them with SIMD integer instructions which are much faster. Consider the current state of CUDA mul_mat_q kernels: they use 700/970/1430 MiB less memory than cuBLAS for 7b/13b/33b and they are up to 2x faster (depending on hardware and quantization format). This is not because I can write the absolute best GEMM kernels but simply because I wrote GEMM kernels that take advantage of the specific ggml data format, both in terms of data types and the memory layout.
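To illustrate the conversion overhead described above, here is a minimal sketch of a block-quantized format and the expansion to FP32 that a regular BLAS call would need first. The `block_q4` struct and `dequantize_q4` function are hypothetical simplifications (a float scale plus 32 packed 4-bit values with an assumed offset of 8) and do not reproduce ggml's actual data layouts.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

enum { QK = 32 }; /* values per block (assumed) */

/* Hypothetical simplified 4-bit block: one scale plus 32 packed nibbles. */
typedef struct {
    float   scale;           /* per-block floating point scale */
    uint8_t nibbles[QK / 2]; /* two 4-bit values per byte, stored as value + 8 */
} block_q4;

/* Expand n quantized values (n must be a multiple of QK) into FP32.
 * This is the extra pass, and the extra buffer, that a generic FP32 GEMM needs. */
static void dequantize_q4(const block_q4 *blocks, size_t n, float *out) {
    assert(n % QK == 0);
    for (size_t b = 0; b < n / QK; ++b) {
        for (size_t i = 0; i < QK / 2; ++i) {
            uint8_t byte = blocks[b].nibbles[i];
            int lo = (int)(byte & 0x0F) - 8; /* low nibble  */
            int hi = (int)(byte >> 4) - 8;   /* high nibble */
            out[b * QK + 2 * i]     = blocks[b].scale * (float)lo;
            out[b * QK + 2 * i + 1] = blocks[b].scale * (float)hi;
        }
    }
}
```

With this simplified layout, one block of 32 values occupies 20 bytes on typical platforms, while its FP32 expansion occupies 128 bytes, which is where the additional memory for the temporary buffer comes from.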

JohannesGaessler avatar Aug 10 '23 16:08 JohannesGaessler

Great idea! I hadn't previously considered the overhead caused by the custom data format. However, I still believe that while minimizing overhead, we should use these BLAS libraries as much as possible to ensure optimal performance across different hardware, because unlike CPUs, GPUs differ significantly in architecture. Even products from the same company can have vast differences between generations. Every time NVIDIA introduces a new GPU architecture, CUDA has to undergo major updates to achieve the best performance on the new hardware. So, a custom-written kernel needs continuous maintenance, and the effort required is substantial. If you really don't want to use BLAS, I suggest you take a look at these:

  • Deep Learning Compilers
  • How Rammer squeezes more out of accelerator performance

bobqianic avatar Aug 11 '23 03:08 bobqianic

I don't think a single file can use all the hardware features of every processor, until Mojo comes around. We don't have a unified language, as CUDA only runs on gaming rigs and high-end laptops (regarding consumer hardware). Metal runs on 1 billion smartphones but is very different. Plus to use simdgroup_async_copy you have to pre-compile offline using a command-line tool from the archived Xcode 14.2 binary (dependency nightmare).

For mat-vec multiplication, it makes sense to dequantize in place. For mat-mat multiplication, dequantizing in-place increases the total number of operations while the ALU is already saturated. Plus, the proposed single-file idea will probably skip important hardware features (e.g. simdgroup_async_copy) that get full ALU saturation in the first place. I've been discussing quantized mat-mat multiplication in another AI application, and we decided on dequantizing to a small scratch MTLBuffer before calling into a pre-compiled FP16 x FP16 GEMM kernel from MFA (not MPS).

NOTE: By single-file I do not mean it has to literally be a single file. But that is the general sense of what this idea seems to be close to.

philipturner avatar Aug 11 '23 09:08 philipturner

For mat-mat multiplication, dequantizing in-place increases the total number of operations while the ALU is already saturated.

The goal is not to dequantize in place but to quantize the hidden state to q8_1 once per matrix matrix multiplication and to then do the calculations entirely using the quantized formats. This lets you replace floating point arithmetic with integer arithmetic or SIMD instructions so it should end up being faster.
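A rough CPU-side sketch of that idea in plain C follows. The `block_q8` struct is a hypothetical simplification (one float scale plus 32 int8 values) and does not reproduce the actual q8_1 layout, and `quantize_q8` and `vec_dot_q8` are illustrative names rather than ggml functions.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

enum { QK8 = 32 }; /* values per block (assumed) */

/* Hypothetical simplified 8-bit block: one scale plus 32 signed integers. */
typedef struct {
    float  scale;
    int8_t q[QK8];
} block_q8;

/* Quantize n floats (n must be a multiple of QK8) into int8 blocks.
 * For the hidden state this would run once per matrix multiplication. */
static void quantize_q8(const float *x, size_t n, block_q8 *out) {
    assert(n % QK8 == 0);
    for (size_t b = 0; b < n / QK8; ++b) {
        float amax = 0.0f; /* largest magnitude in the block */
        for (size_t i = 0; i < QK8; ++i) {
            float v = x[b * QK8 + i];
            float a = v < 0.0f ? -v : v;
            if (a > amax) amax = a;
        }
        float scale = amax / 127.0f;
        float inv = scale > 0.0f ? 1.0f / scale : 0.0f;
        out[b].scale = scale;
        for (size_t i = 0; i < QK8; ++i) {
            float v = x[b * QK8 + i] * inv;
            /* Round to nearest by adding +/-0.5 before truncation. */
            out[b].q[i] = (int8_t)(v >= 0.0f ? v + 0.5f : v - 0.5f);
        }
    }
}

/* Dot product of two quantized vectors: integer multiply-accumulate inside
 * each block, then a single floating point scale per block. */
static float vec_dot_q8(const block_q8 *x, const block_q8 *y, size_t nblocks) {
    float sum = 0.0f;
    for (size_t b = 0; b < nblocks; ++b) {
        int32_t acc = 0;
        for (size_t i = 0; i < QK8; ++i)
            acc += (int32_t)x[b].q[i] * (int32_t)y[b].q[i];
        sum += x[b].scale * y[b].scale * (float)acc;
    }
    return sum;
}
```

The inner loop contains only integer operations, so floating point work drops to one multiply-add per block of 32 values. On real hardware that inner loop is what would map to integer SIMD or dot product instructions.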

JohannesGaessler avatar Aug 11 '23 13:08 JohannesGaessler

The GPU is already a SIMD architecture. Do you mean an optimization only applicable to CPU? If you're using an entire SIMD vector instruction for one scalar, that's underutilizing the SIMD ALU by a factor proportional to vector width.

philipturner avatar Aug 11 '23 15:08 philipturner

I mean to use this instead of floating point arithmetic.

JohannesGaessler avatar Aug 11 '23 15:08 JohannesGaessler

this might be relevant: https://github.com/ashvardanian/SimSIMD

rawwerks avatar Dec 13 '23 02:12 rawwerks