
Add Intel Advanced Matrix Extensions (AMX) support to ggml

Open mingfeima opened this issue 1 year ago • 24 comments

This PR improves Intel server CPU performance with Intel Advanced Matrix Extensions (AMX). AMX is a new built-in accelerator for GEMM available starting from the 4th gen Xeon: https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html

The basic idea is pretty much the same as what I have done in PyTorch (https://github.com/pytorch/pytorch/pull/117475) for the int4 and int8 mixed-dtype GEMMs.

Features

  • It now supports the Q4_0, Q4_1, and Q8_0 quantized formats (I picked these based on the current support for __ARM_FEATURE_MATMUL_INT8); more formats will be added in the future. The kernels are placed in ggml-amx.cpp since I don't want to mess with ggml.c, which is already very complex, and the AMX kernels could also become more complex in the future if more quantization formats are added.
  • Implement fast weight-only quantized GEMM kernels with AMX: q4 is unpacked to q8 and the GEMM is done as s8s8 or u8s8 (a minimal sketch follows this list).
  • Implement a fast path (GEMV) when the batch dimension is 1. When the batch dimension is small, VNNI is usually faster than AMX because AMX has a larger setup overhead.
  • Implement double buffering for post-processing (applying the scales): the scales are stored as f16 in the GGUF, so the tmul instructions have to be interleaved with f32 instructions and we cannot run tmul back to back at the hardware compute limit. Double buffering here improves performance by 8%-11%.
  • The AMX and VNNI kernels produce numerically identical results to the current AVX/AVX2 kernels from ggml-quant.c.
  • The AMX kernels are compiled automatically on CPUs with hardware support, and skipped otherwise.
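
To make the kernel structure above concrete, here is a minimal, self-contained sketch of a single 16x16 s8s8 tile product with AMX-INT8 intrinsics. This is not the PR's ggml-amx.cpp code; the struct and function names are illustrative, and the tile-configuration layout and Linux permission request follow the public AMX documentation (build with -mamx-tile -mamx-int8, GCC 11+).

```cpp
#include <immintrin.h>
#include <cstdint>
#include <sys/syscall.h>
#include <unistd.h>

// 64-byte tile configuration block, palette 1 (layout per the AMX spec).
struct tile_config_t {
    uint8_t  palette_id;
    uint8_t  start_row;
    uint8_t  reserved[14];
    uint16_t colsb[16];   // bytes per row for each tile register
    uint8_t  rows[16];    // number of rows for each tile register
};

// Linux requires opting in to the XTILEDATA state once per process before using AMX.
static bool request_amx_permission() {
    const int ARCH_REQ_XCOMP_PERM = 0x1023, XFEATURE_XTILEDATA = 18;
    return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
}

// C[16x16] (int32) = A[16x64] (int8, row-major) * B (int8, VNNI-packed: 16 rows x 64 bytes,
// i.e. groups of 4 consecutive K values per column). Caller must have requested permission.
static void amx_s8s8_tile_gemm(const int8_t *A, const int8_t *B, int32_t *C) {
    tile_config_t cfg{};
    cfg.palette_id = 1;
    cfg.rows[0] = 16; cfg.colsb[0] = 64;   // tmm0: accumulator, 16 x 16 int32
    cfg.rows[1] = 16; cfg.colsb[1] = 64;   // tmm1: A tile, 16 x 64 int8
    cfg.rows[2] = 16; cfg.colsb[2] = 64;   // tmm2: B tile, VNNI layout
    _tile_loadconfig(&cfg);

    _tile_zero(0);
    _tile_loadd(1, A, 64);                 // strides are in bytes
    _tile_loadd(2, B, 64);
    _tile_dpbssd(0, 1, 2);                 // signed x signed int8 dot-product accumulate
    _tile_stored(0, C, 64);
    _tile_release();
}
```

A real kernel would loop this over the M/N/K blocks, keep B pre-packed in the VNNI layout at load time, and apply the per-block scales afterwards (see the double-buffering item above).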

Performance

  • Results from llama2-7b-q4_0: about a 2x speedup for text generation. Collected on an Intel(R) Xeon(R) CPU Max 9480:
  1. before
llama_print_timings:        load time =     533.79 ms
llama_print_timings:      sample time =       7.29 ms /   200 runs   (    0.04 ms per token, 27453.67 tokens per second)
llama_print_timings: prompt eval time =      77.35 ms /     6 tokens (   12.89 ms per token,    77.57 tokens per second)
llama_print_timings:        eval time =    9333.20 ms /   199 runs   (   46.90 ms per token,    21.32 tokens per second)
llama_print_timings:       total time =    9487.99 ms /   205 tokens
  2. after
llama_print_timings:        load time =     549.56 ms
llama_print_timings:      sample time =       3.73 ms /    96 runs   (    0.04 ms per token, 25751.07 tokens per second)
llama_print_timings: prompt eval time =      67.38 ms /     6 tokens (   11.23 ms per token,    89.05 tokens per second)
llama_print_timings:        eval time =    2245.79 ms /    95 runs   (   23.64 ms per token,    42.30 tokens per second)
llama_print_timings:       total time =    2346.99 ms /   101 tokens
  • Results from benchmark-matmult (metric: GFLOPS):

| cores | before | after   | speedup |
|-------|--------|---------|---------|
| 1     | 50.67  | 260.04  | 5.13    |
| 4     | 171.92 | 1026.29 | 5.97    |
| 16    | 192.38 | 2143.10 | 11.14   |
| 32    | 263.70 | 3694.85 | 14.01   |

TODO:

  • ~add more quantized dtype support~
  • ~add bf16 gemm support with amx-bf16 (using avx512-bf16 for gemv)~
  • ~add f16 gemm support with amx-f16 (using avx512-f16 for gemv)~

I also noticed from VTune that some pointwise operators need additional optimization, e.g. softmax. I will handle them later on.

mingfeima avatar Jun 03 '24 02:06 mingfeima

This PR also adds OpenMP support, since the original pthread synchronization is done via atomics, which have a very high overhead on server CPUs (and the sync has to be done very frequently, once per operator launch). This was not my initial target, but I had to fix it by using another threading runtime, OpenMP or TBB; otherwise the performance speedup would be cut down quite a bit.

I noticed that https://github.com/ggerganov/llama.cpp/pull/7606 is also doing this; that should work as well.
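
For illustration, a minimal sketch of the kind of change this implies (assuming nothing about the actual ggml threadpool code; the function name is hypothetical): letting OpenMP fork, schedule, and join per operator instead of spinning on atomics.

```cpp
#include <omp.h>

// Hypothetical per-operator dispatch: nrows units of work, one kernel call per row.
// The implicit barrier at the end of the parallel-for replaces the hand-rolled
// atomic spin-wait between operator launches.
void compute_op_rows(int nrows, void (*kernel)(int row, void *ctx), void *ctx) {
    #pragma omp parallel for schedule(static)
    for (int row = 0; row < nrows; ++row) {
        kernel(row, ctx);
    }
}
```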

mingfeima avatar Jun 03 '24 02:06 mingfeima

BTW, why does AMX greatly improve next-token latency?

airMeng avatar Jun 03 '24 06:06 airMeng

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 555 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8404.11ms p(95)=20501.82ms fails=, finish reason: stop=500 truncated=55
  • Prompt processing (pp): avg=89.57tk/s p(95)=384.56tk/s
  • Token generation (tg): avg=34.36tk/s p(95)=47.48tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=pr_add_amx_support_1 commit=952af436ea0c5717f06e701108f4b12b93c58260

Charts (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 555 iterations): llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, llamacpp:requests_processing.

github-actions[bot] avatar Jun 03 '24 18:06 github-actions[bot]

Here is my suggestion:

  1. Update the README.md: explain the conditions under which AMX is used to speed up inference (hardware, build parameters) and how to run the CI.

  2. CMake (optional): most backends use CMake to build. It would be great if this could also be built with CMake.

NeoZhangJianyu avatar Jun 04 '24 05:06 NeoZhangJianyu

BTW, why does AMX greatly improve next-token latency?

I also wrote a VNNI kernel for the GEMV cases.
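
As an illustration of such a GEMV inner loop (not the PR's actual kernel; the function name and u8/s8 convention are assumptions based on the feature list above), the VNNI path boils down to repeated 4-way int8 dot products per 32-bit lane:

```cpp
#include <immintrin.h>
#include <cstdint>

// Dot product of u8 activations against s8 weights, k a multiple of 64.
// Requires AVX512F + AVX512-VNNI (-mavx512f -mavx512vnni).
int32_t vnni_dot_u8s8(const uint8_t *a_u8, const int8_t *w_s8, int k) {
    __m512i acc = _mm512_setzero_si512();
    for (int i = 0; i < k; i += 64) {
        __m512i va = _mm512_loadu_si512(a_u8 + i);
        __m512i vw = _mm512_loadu_si512(w_s8 + i);
        acc = _mm512_dpbusd_epi32(acc, va, vw);   // 4 x (u8*s8) accumulated per int32 lane
    }
    return _mm512_reduce_add_epi32(acc);
}
```

Unlike the AMX path, there is no tile configuration to set up, which is why this wins when the batch dimension is 1.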

mingfeima avatar Jun 06 '24 02:06 mingfeima

Here is my suggestion:

  1. Update the README.md: explain the conditions under which AMX is used to speed up inference (hardware, build parameters) and how to run the CI.
  2. CMake (optional): most backends use CMake to build. It would be great if this could also be built with CMake.

Sure, the BKMs for Intel also need to be updated.

mingfeima avatar Jun 06 '24 02:06 mingfeima

Updates: f16 support added

Right now this patch only has an AVX-512 kernel that does FMA with AVX512F (I did not use AVX512-FP16 here because _mm512_fmadd_ph accumulates in 16 bits); a short sketch of this approach follows below. The AMX kernel will be added once the 6th gen Xeon, which has AMX-F16 support, is released.

I am also postponing the bf16 AMX kernels to align with the f16 AMX timeline, since bf16 is not that common in GGUF.
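
A minimal sketch of the AVX512F step described above (widen f16 weights to f32, then accumulate in f32); the helper name and data layout are illustrative, not taken from the patch:

```cpp
#include <immintrin.h>
#include <cstdint>

// acc += w_f16[0..15] * x_f32[0..15], accumulating in f32.
static inline __m512 fma_f16_f32(__m512 acc, const uint16_t *w_f16, const float *x_f32) {
    __m256i wh = _mm256_loadu_si256((const __m256i *) w_f16);   // 16 x f16 weights
    __m512  wf = _mm512_cvtph_ps(wh);                           // widen to f32 (AVX512F)
    __m512  xf = _mm512_loadu_ps(x_f32);
    return _mm512_fmadd_ps(wf, xf, acc);                        // full f32 accumulation
}
```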

Performance: tested with Meta-Llama-3-8B-Instruct-fp16.gguf, about a 1.6x performance improvement in text generation on an Intel(R) Xeon(R) CPU Max 9480 (the 4th gen Xeon):

### before:
llama_print_timings:        eval time =   20867.58 ms /   199 runs   (  104.86 ms per token,     9.54 tokens per second)

### after:
llama_print_timings:        eval time =   13214.77 ms /   199 runs   (   66.41 ms per token,    15.06 tokens per second)

mingfeima avatar Jun 17 '24 13:06 mingfeima

Now that we have the necessary capabilities to implement backends such as the BLAS backend (#6210), would it make sense to implement a similar AMX backend and put this implementation there? @slaren what do you think?

ggerganov avatar Jun 18 '24 08:06 ggerganov

I experimented with this in https://github.com/ggerganov/llama.cpp/commit/f3974cabac841b9b35283a731bc31203288633d4 by moving all matrix multiplication to the BLAS backend. Generally I think the performance is ok, maybe 1-2% slower for small models (<1B), but the difference is very small for most models.

slaren avatar Jun 18 '24 11:06 slaren

I think the implementation is good! I suggest updating the README.md with:

  1. A guide on how to enable this feature.
  2. The conditions required to enable it.
  3. The relationship with AVX512, VNNI, and AMX.

The only thing missing is a guide for this feature.

Thank you!

NeoZhangJianyu avatar Jun 18 '24 12:06 NeoZhangJianyu

I think it would be better to leave the implementation as is instead of moving it to a different backend, the performance would be slightly better, and I don't really see a good reason to split the CPU backend into multiple backends. The changes in ggml.c should not be necessary now that openmp is already used by default.

slaren avatar Jun 18 '24 17:06 slaren

I think it would be better to leave the implementation as is instead of moving it to a different backend, the performance would be slightly better, and I don't really see a good reason to split the CPU backend into multiple backends. The changes in ggml.c should not be necessary now that openmp is already used by default.

AMX is a new built-in accelerator available from the 4th generation of Xeon, the Intel server CPU (link). So this PR is actually trying to improve the performance of llama.cpp on Intel server CPUs. And AMX is not the same concept as BLAS.

@slaren I don't quite get your idea: should I continue with ggml-amx.cpp or move the optimizations somewhere else? My general idea is to put all the AMX-related optimizations in a single file, which would be easier to maintain. The currently available Xeons (the 4th and 5th gen) have the same ISA, but the 6th gen Xeon comes in two different types: E-core and P-core. The 6th gen Xeon will be launched very soon, so I will need to update the AMX-related optimizations for the new hardware in the near future, adding amx-f16 kernels.

The OMP changes in ggml.c will be gone after rebasing. Currently I am working on the QK_K AMX kernels, and I will clean up this PR once that is done.

mingfeima avatar Jun 19 '24 03:06 mingfeima

should I continue with ggml-amx.cpp or move the optimizations to somewhere else?

I was responding to @ggerganov's suggestion to move the implementation to a different backend similar to the BLAS backend. I think you should continue as is.

slaren avatar Jun 19 '24 05:06 slaren

Ok, let's proceed as is

ggerganov avatar Jun 19 '24 08:06 ggerganov

Is there any progress? I am really looking forward to the AMX support.

nai-kon avatar Jul 04 '24 05:07 nai-kon

Is there any progress? I am really looking forward to the AMX support.

Recently I got distracted by some other tasks; I work on this project in my spare time, as it is not an official task from my employer. Currently I am working on the Q4_K quant format, and I have to say it is much more complex... Anyway, it's about to be finished.

mingfeima avatar Jul 05 '24 05:07 mingfeima

Added AMX and VNNI kernels for Q4_K, Q5_K, Q6_K, IQ4_XS.

mingfeima avatar Jul 18 '24 07:07 mingfeima

I like that the code is very well isolated from the rest of the codebase. Haven't reviewed the mmq.cpp source in details yet and it would be difficult without the appropriate hardware, but I think that's alright as we can easily determine that there won't be side effects to the rest of the codebase.

Wondering how we could add some tests for this functionality.

Overall, seems good to me. @slaren What do you think?

Yeah... the CI is a big problem. I will try to find an internal sponsor so that we can use our company cloud; that would be the best. Otherwise, we will have to go with an emulator.

@ggerganov I was wondering how the Ascend910B3 functionality is tested in the CI? Does Huawei provide the CI support? And what about the aarch64 functionality for Arm servers?

mingfeima avatar Jul 25 '24 05:07 mingfeima

We don't have CI for the CANN backend either. For aarch64, I'm planning to try to rent an Arm machine on the Azure cloud when they become available, if they are not too expensive.

ggerganov avatar Jul 25 '24 06:07 ggerganov

I think the Sapphire Rapids Xeon (4th generation Xeon) supports AMX. In Azure, the DCesv5-series and DCedsv5-series are powered by Intel® 4th Generation Xeon® Scalable processors (https://learn.microsoft.com/en-us/azure/virtual-machines/dcesv5-dcedsv5-series), so they should support AMX. It should be possible to build the CI on them.

NeoZhangJianyu avatar Jul 29 '24 05:07 NeoZhangJianyu

Hi, I noticed some quantization issues in mmq.cpp. https://github.com/mingfeima/llama.cpp/blob/74bb1eb52be7d9b9eb484d156d24a474dd09f278/ggml/src/ggml-amx/mmq.cpp#L1183-L1195 Here, we are using a single scale vd0 for all 16x32 weights. However, Q8_0 uses one scale parameter per blck_size=32 elements.

  1. Is this compensated for by agreeing on a single scale parameter for all 16x32 weights? I don't see any code in pack_B, etc. doing so.
  2. Wouldn't this be a different quantization method if we use a single scale for 512 weights? It would give different results compared with the existing Q8_0 AVX-based methods.

gyusang avatar Jul 31 '24 07:07 gyusang

Hi, I noticed some quantization issues in mmq.cpp. https://github.com/mingfeima/llama.cpp/blob/74bb1eb52be7d9b9eb484d156d24a474dd09f278/ggml/src/ggml-amx/mmq.cpp#L1183-L1195 Here, we are using a single scale vd0 for all 16x32 weights. However, Q8_0 uses one scale parameter per blck_size=32 elements.

  1. Is this compensated for by agreeing on a single scale parameter for all 16x32 weights? I don't see any code in pack_B, etc. doing so.
  2. Wouldn't this be a different quantization method if we use a single scale for 512 weights? It would give different results compared with the existing Q8_0 AVX-based methods.

The weight packing for Q8_0 is here: https://github.com/mingfeima/llama.cpp/blob/74bb1eb52be7d9b9eb484d156d24a474dd09f278/ggml/src/ggml-amx/mmq.cpp#L866-L873

Each weight block of 16x32 (NxK) is stored in (KxN) format so that we can do the FMA here, and each block has 16 scales (d0), packed as a contiguous 1x16 vector with dtype f16. So, to sum up, the scale is a 256-bit vector corresponding to 16 columns; it is not a "single scale parameter for all 16x32 weights". If the computation were wrong, the LLM would talk like crazy.
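
To illustrate what applying that 16-wide scale vector looks like (this is not the code from mmq.cpp; the function name, the activation scale d1, and the layout are assumptions), one row of int32 tile output could be scaled roughly like this:

```cpp
#include <immintrin.h>
#include <cstdint>

// acc_i32: one output row of 16 int32 sums from the s8s8 tile GEMM.
// d0_f16:  the 16 per-column weight scales (f16), one per 32-element Q8_0 block.
// d1:      the activation-side scale for the same K block.
static inline __m512 apply_q8_0_scales(__m512i acc_i32, const uint16_t *d0_f16, float d1) {
    __m512 d0  = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *) d0_f16));
    __m512 sum = _mm512_cvtepi32_ps(acc_i32);
    return _mm512_mul_ps(sum, _mm512_mul_ps(d0, _mm512_set1_ps(d1)));
}
```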

mingfeima avatar Aug 05 '24 04:08 mingfeima

@ggerganov On Azure, the DCesv5 and ECesv5 instances have Intel AMX support; they are all 4th gen Xeon (codename Sapphire Rapids): https://azure.microsoft.com/en-us/updates/confidential-vms-with-intel-tdx-dcesv5-ecesv5/ https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/general-purpose/dcesv5-series?tabs=sizebasic

Would it be possible to use those instances for CI? The AMX features are compiled by default on CPUs with AMX support; the minimum required GCC version is gcc-11.

mingfeima avatar Aug 08 '24 05:08 mingfeima

Thanks for letting me know - I just added an AMX VM (EC8eds v5) to the ggml-ci fleet:

https://github.com/ggml-org/ci/tree/results/llama.cpp/15/fa07a5c564d3ed7e7eb64b73272cedb27e73ec/ggml-5-x86-amx-cc#summary

It won't run on this PR since ggml-ci runs only on branches in this repository. So the AMX CI will run after we merge the PR in master.

I've also sent you a collaborator invite; if you'd like, you will be able to push branches in this repository and run the CI prior to merging in the future.

ggerganov avatar Aug 08 '24 07:08 ggerganov

Hi, just checking in - any progress on merging this into master?

artun42 avatar Mar 05 '25 22:03 artun42

It's already merged: https://github.com/ggml-org/llama.cpp/pull/8998

ggerganov avatar Mar 06 '25 05:03 ggerganov

If I have a Sapphire Rapids processor, which is AMX-enabled, how do I ensure that AMX is enabled in llama.cpp?

Currently I am building it with

cmake -B build -DGGML_CUDA=ON -DGGML_RPC=ON
cmake --build build --config Release -j 56

Should I add an argument like this at the end?

-DGGML_USE_AMX=ON

mtcl avatar Jun 15 '25 04:06 mtcl