Add Intel Advanced Matrix Extensions (AMX) support to ggml
This PR improves Intel server CPU performance with Intel Advanced Matrix Extensions (AMX). AMX is a new built-in accelerator for GEMM, available starting from the 4th Gen Xeon: https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html
The basic idea is pretty much the same as what I have done in PyTorch (https://github.com/pytorch/pytorch/pull/117475) for the int4 and int8 mixed-dtype GEMMs.
Features
- It currently supports the Q4_0, Q4_1, and Q8_0 quantized formats (I picked these formats based on what is currently covered by __ARM_FEATURE_MATMUL_INT8); more formats will be added in the future. The kernels are placed in ggml-amx.cpp since I don't want to mess with ggml.c, which is already very complex, and the AMX kernels could also grow more complex in the future if more quant formats are added.
- Implements fast weight-only quantized GEMM kernels with AMX: Q4 weights are unpacked to Q8 and the compute is done as s8s8 or u8s8 GEMM (a minimal tile-multiply sketch follows this list).
- Implements a fast GEMV path when the batch dimension is 1; when the batch dimension is small, VNNI is usually faster than AMX because AMX has a larger overhead.
- Implements double buffering for post-processing (applying scales): the scales are stored as f16 in the GGUF file, so the tile multiplies (tmul) have to be interleaved with f32 instructions, and the kernel cannot hit the hardware compute limit since the tmuls cannot run back to back. Applying double buffering here improves performance by 8%-11%.
- The AMX and VNNI kernels produce numerically identical results to the current AVX/AVX2 kernels from ggml-quant.c.
- The AMX kernels are compiled automatically on CPUs with hardware support, and skipped otherwise.
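For context, here is a minimal, self-contained sketch of the kind of s8s8 tile multiply the AMX GEMM kernels are built around. This is an illustration only, not the PR's kernel; the function name, the fixed 16x16x64 shape, and the standalone tile configuration are assumptions.

```cpp
// Sketch: C[16x16] (int32) += A[16x64] (int8) * B (int8, VNNI-packed, 16 rows x 64 bytes).
// Requires gcc >= 11 with -mamx-tile -mamx-int8; on Linux the process must also
// request AMX state via arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA).
#include <immintrin.h>
#include <cstdint>
#include <cstring>

struct tile_config_t {
    uint8_t  palette_id;
    uint8_t  start_row;
    uint8_t  reserved[14];
    uint16_t colsb[16];  // bytes per row for each tile register
    uint8_t  rows[16];   // rows for each tile register
};

static void gemm_s8s8_16x16x64(const int8_t *A, const int8_t *B, int32_t *C) {
    tile_config_t cfg;
    std::memset(&cfg, 0, sizeof(cfg));
    cfg.palette_id = 1;
    cfg.rows[0] = 16; cfg.colsb[0] = 64;  // tmm0: C, 16 rows x 16 int32
    cfg.rows[1] = 16; cfg.colsb[1] = 64;  // tmm1: A, 16 rows x 64 int8
    cfg.rows[2] = 16; cfg.colsb[2] = 64;  // tmm2: B, K/4 = 16 rows x 4*N = 64 bytes (VNNI layout)
    _tile_loadconfig(&cfg);

    _tile_zero(0);
    _tile_loadd(1, A, 64);   // strides are in bytes
    _tile_loadd(2, B, 64);
    _tile_dpbssd(0, 1, 2);   // int8 x int8 dot products accumulated into int32
    _tile_stored(0, C, 64);
    _tile_release();
}
```

In practice the kernels tile over larger M/N/K and keep the tile configuration loaded across blocks; the per-block f16 scales are applied afterwards, which is where the double buffering mentioned above comes in.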
Performance
- results from llama2-7b-q4_0: about a 2x speedup for text generation. Collected on Intel(R) Xeon(R) CPU Max 9480:
- before
llama_print_timings: load time = 533.79 ms
llama_print_timings: sample time = 7.29 ms / 200 runs ( 0.04 ms per token, 27453.67 tokens per second)
llama_print_timings: prompt eval time = 77.35 ms / 6 tokens ( 12.89 ms per token, 77.57 tokens per second)
llama_print_timings: eval time = 9333.20 ms / 199 runs ( 46.90 ms per token, 21.32 tokens per second)
llama_print_timings: total time = 9487.99 ms / 205 tokens
- after
llama_print_timings: load time = 549.56 ms
llama_print_timings: sample time = 3.73 ms / 96 runs ( 0.04 ms per token, 25751.07 tokens per second)
llama_print_timings: prompt eval time = 67.38 ms / 6 tokens ( 11.23 ms per token, 89.05 tokens per second)
llama_print_timings: eval time = 2245.79 ms / 95 runs ( 23.64 ms per token, 42.30 tokens per second)
llama_print_timings: total time = 2346.99 ms / 101 tokens
- results from benchmark-matmult (metric: gFLOPS):
| cores | before | after | speedup |
|---|---|---|---|
| 1 | 50.67 | 260.04 | 5.13 |
| 4 | 171.92 | 1026.29 | 5.97 |
| 16 | 192.38 | 2143.1 | 11.14 |
| 32 | 263.7 | 3694.85 | 14.01 |
TODO:
- ~~add more quantized dtype support~~
- ~~add bf16 gemm support with amx-bf16 (using avx512-bf16 for gemv)~~
- ~~add f16 gemm support with amx-f16 (using avx512-f16 for gemv)~~
I also noticed from VTune that some pointwise operators need additional optimization, e.g. softmax. Will handle them later on.
This PR also adds OpenMP support, since the original pthread sync is done via atomics, which have a very high overhead on server CPUs (and the sync has to be done very frequently, for each operator launch). This was not my initial target, but I had to fix it by using another threading runtime, OpenMP or TBB; otherwise the performance speedup would be cut down quite a bit. A rough illustration of the idea is sketched below.
I noticed that https://github.com/ggerganov/llama.cpp/pull/7606 is also doing this; that should also work.
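For illustration, a minimal sketch of the kind of change this implies, with hypothetical names (compute_op_rows, compute_row); it simply dispatches an operator's work over an OpenMP parallel loop instead of synchronizing hand-rolled threads with atomics.

```cpp
// Sketch only (compile with -fopenmp): the OpenMP runtime keeps its worker pool
// alive between parallel regions, so the per-operator launch/sync overhead is
// much lower than having every thread spin on an atomic counter per graph node.
#include <omp.h>

typedef void (*row_fn_t)(int row, void *ctx);

void compute_op_rows(int nrows, row_fn_t compute_row, void *ctx) {
    #pragma omp parallel for schedule(static)
    for (int row = 0; row < nrows; ++row) {
        compute_row(row, ctx);
    }
}
```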
BTW, why will AMX greatly improve next-token latency?
📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 555 iterations 🚀
Expand details for performance related PR only
- Concurrent users: 8, duration: 10m
- HTTP request : avg=8404.11ms p(95)=20501.82ms fails=, finish reason: stop=500 truncated=55
- Prompt processing (pp): avg=89.57tk/s p(95)=384.56tk/s
- Token generation (tg): avg=34.36tk/s p(95)=47.48tk/s
- ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=pr_add_amx_support_1 commit=952af436ea0c5717f06e701108f4b12b93c58260
(The original comment embedded four benchmark charts, "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 555 iterations", plotting llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, and llamacpp:requests_processing over the run.)
Here is my suggestion:
- Update the README.md: explain the conditions under which AMX will be used to speed up inference (hardware, build parameters) and how to run the CI.
- CMake (optional): most backends use CMake to build; it would be great to support building this with CMake as well.
> BTW, why will AMX greatly improve next-token latency?

I also wrote a VNNI kernel for the gemv case: next-token generation runs with batch size 1, which takes that path, so it benefits as well. A minimal sketch of that kind of inner product is below.
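The following is an illustration only, not the PR's kernel; it assumes u8 activations against s8 weights over a single 64-byte chunk.

```cpp
// Requires AVX512F + AVX512-VNNI (e.g. gcc -mavx512f -mavx512vnni).
#include <immintrin.h>
#include <cstdint>

// _mm512_dpbusd_epi32 multiplies 4 adjacent u8*s8 pairs and accumulates them
// into each 32-bit lane, so one instruction covers 64 byte-products.
int32_t dot_u8s8_64(const uint8_t *a, const int8_t *b) {
    __m512i va  = _mm512_loadu_si512(a);
    __m512i vb  = _mm512_loadu_si512(b);
    __m512i acc = _mm512_dpbusd_epi32(_mm512_setzero_si512(), va, vb);
    return _mm512_reduce_add_epi32(acc);
}
```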
> Here is my suggestion:
> - Update the README.md: explain the conditions under which AMX will be used to speed up inference (hardware, build parameters) and how to run the CI.
> - CMake (optional): most backends use CMake to build; it would be great to support building this with CMake as well.

Sure, the BKMs (best known methods) for Intel also need to be updated.
Updates: f16 support added
Right now this patch only has an AVX512 kernel that does FMA with avx512f (I did not use avx512-fp16 here, as _mm512_fmadd_ph does its accumulation in 16 bits). The AMX kernel will be added once the 6th Gen Xeon, which has amx-f16 support, is released. A sketch of the avx512f approach is shown below.
Also postponing bf16 AMX kernel support to align with the f16 AMX timeline, since bf16 is not that common in GGUF.
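A minimal sketch of the avx512f approach described above, with f32 accumulation (an illustration under assumptions, not the PR's kernel; n is assumed to be a multiple of 16):

```cpp
// Requires AVX512F (gcc -mavx512f); f16 weights are widened to f32 before the FMA,
// so accumulation happens in 32 bits rather than the 16-bit accumulation of
// _mm512_fmadd_ph.
#include <immintrin.h>
#include <cstdint>

float dot_f16_f32(const uint16_t *w_f16, const float *x, int n) {
    __m512 acc = _mm512_setzero_ps();
    for (int i = 0; i < n; i += 16) {
        __m512 w = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(w_f16 + i)));
        acc = _mm512_fmadd_ps(w, _mm512_loadu_ps(x + i), acc);
    }
    return _mm512_reduce_add_ps(acc);
}
```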
Performance: tested with Meta-Llama-3-8B-Instruct-fp16.gguf, about a 1.6x performance improvement in text generation, measured on Intel(R) Xeon(R) CPU Max 9480 (4th Gen Xeon):
### before:
llama_print_timings: eval time = 20867.58 ms / 199 runs ( 104.86 ms per token, 9.54 tokens per second)
### after:
llama_print_timings: eval time = 13214.77 ms / 199 runs ( 66.41 ms per token, 15.06 tokens per second)
Now that we have the necessary capabilities to implement backends such as the BLAS backend (#6210), would it make sense to implement a similar AMX backend and put this implementation there? @slaren what do you think?
I experimented with this in https://github.com/ggerganov/llama.cpp/commit/f3974cabac841b9b35283a731bc31203288633d4 by moving all matrix multiplication to the BLAS backend. Generally I think the performance is ok, maybe 1-2% slower for small models (<1B), but the difference is very small for most models.
I think the implementation is good! I suggest updating the README.md with:
- a guide on how to enable this feature,
- the conditions for enabling it,
- the relationship between AVX512, VNNI, and AMX.
The only thing missing is a guide for this feature.
Thank you!
I think it would be better to leave the implementation as is instead of moving it to a different backend; the performance would be slightly better, and I don't really see a good reason to split the CPU backend into multiple backends. The changes in ggml.c should not be necessary now that OpenMP is already used by default.
> I think it would be better to leave the implementation as is instead of moving it to a different backend; the performance would be slightly better, and I don't really see a good reason to split the CPU backend into multiple backends. The changes in ggml.c should not be necessary now that OpenMP is already used by default.
AMX is a new built-in accelerator available from the 4th generation of Xeon, the Intel server CPU (link). So this PR is actually trying to improve the performance of llama.cpp on Intel server CPUs, and AMX is not the same concept as BLAS.
@slaren I don't quite get your idea: should I continue with ggml-amx.cpp or move the optimizations somewhere else?
My general idea is to put all the AMX-related optimizations in a single file, which would be easier to maintain. The currently available Xeons (4th gen and 5th gen) have the same ISA, but the 6th Gen Xeon has two different types, E-core and P-core. The 6th Gen Xeon will be launched very soon, so I need to update the AMX-related optimizations for the new hardware in the near future by adding amx-f16 kernels.
The OMP changes in ggml.c will be gone after rebasing. Currently I am working on the QK_K AMX kernels, and I will clean up this PR once that is done.
> should I continue with ggml-amx.cpp or move the optimizations somewhere else?
I was responding to @ggerganov's suggestion to move the implementation to a different backend, similar to the BLAS backend. I think you should continue as is.
Ok, let's proceed as is
Is there any progress? I am really looking forward to the AMX support.
> Is there any progress? I am really looking forward to the AMX support.
Recently I got distracted by some other tasks; I work on this project in my spare time, as it is not an official task from my employer. Currently I am working on the Q4_K quant format, and I have to say it is much more complex... Anyway, it is about to be finished.
Added AMX and VNNI kernels for Q4_K, Q5_K, Q6_K, IQ4_XS.
I like that the code is very well isolated from the rest of the codebase. I haven't reviewed the mmq.cpp source in detail yet, and it would be difficult without the appropriate hardware, but I think that's alright, as we can easily determine that there won't be side effects on the rest of the codebase. Wondering how we could add some tests for this functionality.
Overall, seems good to me. @slaren What do you think?
Yeah... the CI is a big problem. I will try to find an internal sponsor so we can use our company cloud; that would be the best. Otherwise, we will have to go with an emulator.
@ggerganov I was wondering how the Ascend910B3 functionality is tested in the CI? Does Huawei provide the CI support? And how about the aarch64 functionality for Arm servers?
We don't have CI for the CANN backend either. For aarch64, I'm planning to try to rent an Arm machine on the Azure cloud when they become available, if they are not too expensive.
I think the Sapphire Rapids Xeon (4th generation Xeon) supports AMX. In Azure, the DCesv5-series and DCedsv5-series are powered by Intel(R) 4th Generation Xeon(R) Scalable processors (https://learn.microsoft.com/en-us/azure/virtual-machines/dcesv5-dcedsv5-series), so they should support AMX. It should be possible to build CI on them.
Hi, I noticed some quantization issues in mmq.cpp.
https://github.com/mingfeima/llama.cpp/blob/74bb1eb52be7d9b9eb484d156d24a474dd09f278/ggml/src/ggml-amx/mmq.cpp#L1183-L1195
Here, we are using a single scale vd0 for all 16x32 weights. However, Q8_0 uses a scale parameter per blck_size=32 elements.
- Is this compensated for by agreeing on a single scale parameter for all 16x32 weights? I don't see any code in pack_B, etc. doing so.
- Wouldn't this be another quantization method if we use a single scale for 512 weights? This would give different results compared with the existing Q8_0 AVX-based methods.
> Hi, I noticed some quantization issues in mmq.cpp.
> https://github.com/mingfeima/llama.cpp/blob/74bb1eb52be7d9b9eb484d156d24a474dd09f278/ggml/src/ggml-amx/mmq.cpp#L1183-L1195
> Here, we are using a single scale vd0 for all 16x32 weights. However, Q8_0 uses a scale parameter per blck_size=32 elements.
> - Is this compensated for by agreeing on a single scale parameter for all 16x32 weights? I don't see any code in pack_B, etc. doing so.
> - Wouldn't this be another quantization method if we use a single scale for 512 weights? This would give different results compared with the existing Q8_0 AVX-based methods.
The weight packing for Q8_0 is here: https://github.com/mingfeima/llama.cpp/blob/74bb1eb52be7d9b9eb484d156d24a474dd09f278/ggml/src/ggml-amx/mmq.cpp#L866-L873
Each weight block of 16x32 (NxK) is stored in (KxN) format so that we can do FMA here, and this block has 16 scales (d0), packed as a contiguous 1x16 vector with dtype f16. So, to sum up, the scale is a 256-bit vector corresponding to the 16 columns; it is not a "single scale parameter for all 16x32 weights". If the computation were wrong, the LLM would talk like crazy. A rough sketch of applying such per-column scales is shown below.
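The following is a rough illustration of applying per-column scales after the int8 tile multiply; the helper name, the d1 activation scale, and the in-place accumulation are assumptions, not the PR's code.

```cpp
#include <immintrin.h>
#include <cstdint>

// acc_i32: 16x16 int32 tile result; d0_f16: 16 f16 scales, one per output column.
void apply_block_scales(const int32_t *acc_i32, const uint16_t *d0_f16,
                        float d1, float *out, int ldout) {
    // widen the 16 f16 scales (a 256-bit vector) to f32 and fold in the
    // activation scale, giving one f32 scale per output column
    __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0_f16));
    __m512 vd  = _mm512_mul_ps(vd0, _mm512_set1_ps(d1));
    for (int m = 0; m < 16; ++m) {
        __m512 vi = _mm512_cvtepi32_ps(_mm512_loadu_si512(acc_i32 + m * 16));
        __m512 vo = _mm512_loadu_ps(out + m * ldout);
        _mm512_storeu_ps(out + m * ldout, _mm512_fmadd_ps(vi, vd, vo));
    }
}
```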
@ggerganov On Azure, the DCesv5 and ECesv5 instances have Intel AMX support; they are all 4th Gen Xeon (codename Sapphire Rapids):
https://azure.microsoft.com/en-us/updates/confidential-vms-with-intel-tdx-dcesv5-ecesv5/
https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/general-purpose/dcesv5-series?tabs=sizebasic
Is it possible to use those instances for CI? The AMX features are compiled by default on CPUs with AMX support; the minimum required GCC version is gcc-11.
Thanks for letting me know - I just added an AMX VM (EC8eds v5) to the ggml-ci fleet:
https://github.com/ggml-org/ci/tree/results/llama.cpp/15/fa07a5c564d3ed7e7eb64b73272cedb27e73ec/ggml-5-x86-amx-cc#summary
It won't run on this PR since ggml-ci runs only on branches in this repository. So the AMX CI will run after we merge the PR into master.
I've also sent you a collaborator invite, if you'd like you will be able to push branches in this repository and be able to run the CI prior to merging in the future.
Hi, just checking in - any progress to merging this to master?
It's already merged: https://github.com/ggml-org/llama.cpp/pull/8998
If I have a Sapphire Rapids processor, which is AMX enabled, how do I ensure that AMX is enabled in llama.cpp?
Currently I am building it with
cmake -B build -DGGML_CUDA=ON -DGGML_RPC=ON
cmake --build build --config Release -j 56
Should I add an argument like this at the end?
-DGGML_USE_AMX=ON