
llamafile : improve sgemm.cpp

jart opened this issue 2 months ago • 3 comments

  • Re-enable by default
  • Fix issue described in #6716
  • Make code more abstract, elegant, and maintainable
  • Faster handling of weirdly shaped m and n edge cases

jart avatar Apr 20 '24 20:04 jart

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 465 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP requests: avg=10185.65ms p(95)=27005.29ms fails=, finish reason: stop=416 truncated=49
  • Prompt processing (pp): avg=105.0tk/s p(95)=462.15tk/s
  • Token generation (tg): avg=25.75tk/s p(95)=36.86tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=sgemm2 commit=d3d40bfd1e20d7c77029081c56188c3381a1a1b4

[Four Mermaid time-series charts omitted: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, and llamacpp:requests_processing, each plotting the 10m run (465 iterations) of llama.cpp bench-server-baseline on Standard_NC4as_T4_v3.]

github-actions[bot] avatar Apr 20 '24 20:04 github-actions[bot]

@ggerganov Fixed. PTAL

Please give Q8_0 a try (build it with make -j32 LLAMA_NO_ACCELERATE=1 LLAMA_NO_METAL=1 main). On my Mac Studio M2 Ultra, Mistral 7B Q8_0 prompt eval goes from 90 tok/sec to 140 tok/sec with this change. That's 16% faster than Apple Accelerate's cblas_sgemm(), which goes 120 tok/sec for me. The other quants I've tried (Q4_0 and F16) seem to be equal to Accelerate in speed.

Do you know if there's a better way to fix the Android issue? IIUC it's due to an instruction not being available on 32-bit ARM. Is there a way we could solve that with an #ifdef instead? I don't own a 32-bit ARM system, so I have no way of doing this myself.
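
For illustration, a minimal sketch of such a guard, assuming the offending intrinsics are only available on AArch64 (the macro name here is made up):

// Hypothetical compile-time gate: assume the unsupported instruction is
// AArch64-only, and compile the SIMD path out on 32-bit ARM so ggml can
// fall back to its scalar kernels there.
#if defined(__aarch64__) || defined(_M_ARM64)
#define SGEMM_HAVE_ARM64 1
#else
#define SGEMM_HAVE_ARM64 0
#endif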

jart avatar Apr 21 '24 19:04 jart

@ggerganov Could you advise me on how I might bring the benefits of llamafile_sgemm() to GGML_MUL_MAT_ID? I know very little about mixture of expert architecture. It's not obvious to me how I might go about decomposing that operation into 2d matrix multiplications.
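
For context, one way such a decomposition could look, assuming top-1 routing (each token row is assigned to exactly one expert): gather the rows routed to each expert, run one ordinary 2D matmul per expert, then scatter the results back. The sketch below is purely illustrative; none of these names are ggml's actual API.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Naive stand-in for a 2D matmul kernel in the spirit of llamafile_sgemm():
// computes C(i,j) = dot(row i of A, row j of B) for an m x k matrix A and
// an n x k matrix B, writing C column by column (n columns of height m).
static void sgemm_2d(int m, int n, int k,
                     const float *A, const float *B, float *C) {
    for (int j = 0; j < n; j++)
        for (int i = 0; i < m; i++) {
            float s = 0.0f;
            for (int l = 0; l < k; l++)
                s += A[(size_t)i*k + l] * B[(size_t)j*k + l];
            C[(size_t)j*m + i] = s;
        }
}

// Illustrative decomposition of a MUL_MAT_ID-style op into plain 2D GEMMs,
// assuming top-1 routing (each token row goes to exactly one expert).
static void mixmul_sketch(int n_expert, int n_tokens, int m, int k,
                          const float   *expert_w, // [n_expert][m][k]
                          const float   *src,      // [n_tokens][k]
                          const int32_t *route,    // [n_tokens] -> expert id
                          float         *dst) {    // [n_tokens][m]
    for (int e = 0; e < n_expert; e++) {
        // 1) gather the token rows routed to expert e
        std::vector<int> rows;
        for (int t = 0; t < n_tokens; t++)
            if (route[t] == e) rows.push_back(t);
        if (rows.empty()) continue;
        std::vector<float> B(rows.size() * (size_t)k);
        for (size_t r = 0; r < rows.size(); r++)
            std::copy_n(src + (size_t)rows[r]*k, k, B.data() + r*(size_t)k);
        // 2) one ordinary 2D matmul against this expert's weight matrix
        std::vector<float> C(rows.size() * (size_t)m);
        sgemm_2d(m, (int)rows.size(), k, expert_w + (size_t)e*m*k, B.data(), C.data());
        // 3) scatter results back to the routed token positions
        for (size_t r = 0; r < rows.size(); r++)
            std::copy_n(C.data() + r*(size_t)m, m, dst + (size_t)rows[r]*m);
    }
}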

jart avatar Apr 21 '24 22:04 jart

On master with Accelerate I get:

make clean && LLAMA_NO_METAL=1 make -j && ./llama-bench -m models/mistral-7b-v0.2/ggml-model-fp16.gguf -m models/mistral-7b-v0.2/ggml-model-q8_0.gguf -m models/mistral-7b-v0.2/ggml-model-q4_0.gguf -ngl 0 -n 0
model          size       params  backend  threads  test    t/s
llama 8B F16   13.49 GiB  7.24 B  BLAS     16       pp 512  152.87 ± 1.06
llama 8B Q8_0  7.17 GiB   7.24 B  BLAS     16       pp 512  147.44 ± 5.19
llama 8B Q4_0  3.83 GiB   7.24 B  BLAS     16       pp 512  149.98 ± 1.63

build: 8960fe86 (2713)

With this PR without Accelerate:

make clean && LLAMA_NO_ACCELERATE=1 LLAMA_NO_METAL=1 make -j && ./llama-bench -m models/mistral-7b-v0.2/ggml-model-fp16.gguf -m models/mistral-7b-v0.2/ggml-model-q8_0.gguf -m models/mistral-7b-v0.2/ggml-model-q4_0.gguf -ngl 0 -n 0
model          size       params  backend  threads  test    t/s
llama 7B F16   13.49 GiB  7.24 B  CPU      16       pp 512  172.84 ± 0.39
llama 7B Q8_0  7.17 GiB   7.24 B  CPU      16       pp 512  146.22 ± 0.44
llama 7B Q4_0  3.83 GiB   7.24 B  CPU      16       pp 512  123.81 ± 0.43

build: 6b220dca (2704)

So for me, F16 is faster now, Q8_0 is the same and Q4_0 is slower.

Btw, I've looked some more, and I think the proper call in ggml.c should be like this:

diff --git a/ggml.c b/ggml.c
index e3356bdb..086db96a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10878,15 +10878,13 @@ UseGgmlGemm1:;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 #if GGML_USE_LLAMAFILE
-    if (src1_cont) {
+    if (src1->type != vec_dot_type) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + ggml_row_size(vec_dot_type,
-                                         nb12/ggml_type_size(src1->type)*i12 +
-                                         nb13/ggml_type_size(src1->type)*i13),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
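
The intuition behind that offset: wdata holds src1 converted to vec_dot_type row by row, so the block for batch indices (i12, i13) starts (i12*ne11 + i13*ne12*ne11) rows of row_size bytes into the buffer. A tiny worked example with made-up dimensions:

#include <cstdint>
#include <cstdio>

// Illustrative numbers only: ne11 = 4 src1 rows per matrix, ne12 = 2
// matrices per outer batch, and row_size = 34 bytes (one converted Q8_0
// row of 32 elements: 32 int8 quants plus a 2-byte fp16 scale).
int main() {
    const int64_t ne11 = 4, ne12 = 2, ne13 = 2, row_size = 34;
    for (int64_t i13 = 0; i13 < ne13; i13++)
        for (int64_t i12 = 0; i12 < ne12; i12++)
            std::printf("(i12=%lld, i13=%lld) -> wdata offset %lld bytes\n",
                        (long long)i12, (long long)i13,
                        (long long)((i12*ne11 + i13*ne12*ne11)*row_size));
}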

Regarding Android: in ggml-quants.c we do stuff like this to provide 32-bit ARM compatibility:

https://github.com/ggerganov/llama.cpp/blob/e931888d5024de814ce7119a18d6a959bfff3821/ggml-quants.c#L291-L307
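
For readers without the link open, the shims there look roughly like the following (reconstructed from memory of ggml's dot-product fallback, not copied verbatim): when the ARMv8.2 dot-product extension is unavailable, vdotq is emulated with widening multiplies and pairwise adds.

#if defined(__ARM_NEON) && !defined(__ARM_FEATURE_DOTPROD)
#include <arm_neon.h>
// Emulate the ARMv8.2 sdot instruction on CPUs that lack it: widen the
// int8 products to int16 with vmull, then pairwise-accumulate into the
// int32 accumulator.
static inline int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#endif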

We should not repeat the same implementation twice; we have to see if we can reuse it. I also don't have Android builds set up, and it takes me some time to get a build running. So for now, let's focus on fixing the SGEMM, and later we can think about improving Android support.

Let me think for some time about GGML_MUL_MAT_ID support.

ggerganov avatar Apr 22 '24 14:04 ggerganov

Review comments addressed. PTAL. Agree on Android 32-bit.

Let me think for some time about GGML_MUL_MAT_ID support.

I studied the code for hours and managed to figure it out. I've got a llamafile_mixmul() function working now that enables Mixtral to go 2x faster on my machine for prompt processing.

jart avatar Apr 22 '24 18:04 jart

These changes cause failed assertions when running Cohere's Command R+ model:

main: sgemm.cpp:827: bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int): Assertion `1ll * lda * m <= 0x7fffffff' failed.
[the same assertion line repeats 32 times, once per worker thread]

Thread 22 "main" received signal SIGABRT, Aborted.
[Switching to Thread 0x7fcf41ffb640 (LWP 108365)]
__pthread_kill_implementation (no_tid=0, signo=6, threadid=140528142235200) at ./nptl/pthread_kill.c:44
44	./nptl/pthread_kill.c: No such file or directory.
(gdb) bt
#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=140528142235200) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=140528142235200) at ./nptl/pthread_kill.c:78
#2  __GI___pthread_kill (threadid=140528142235200, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3  0x00007ffff7a99476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007ffff7a7f7f3 in __GI_abort () at ./stdlib/abort.c:79
#5  0x00007ffff7a7f71b in __assert_fail_base (fmt=0x7ffff7c34130 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", 
    assertion=0x555555811556 "1ll * lda * m <= 0x7fffffff", file=0x55555581150a "sgemm.cpp", line=827, 
    function=<optimized out>) at ./assert/assert.c:92
#6  0x00007ffff7a90e96 in __GI___assert_fail (assertion=0x555555811556 "1ll * lda * m <= 0x7fffffff", 
    file=0x55555581150a "sgemm.cpp", line=827, 
    function=0x555555811498 "bool llamafile_sgemm(int, int, int, const void*, int, const void*, int, void*, int, int, int, int, int, int, int)") at ./assert/assert.c:101
#7  0x000055555576c199 in llamafile_sgemm (m=256000, n=1, k=12288, A=0x7fcf9dd7b500, lda=12288, B=0x7fcf88918020, 
    ldb=12288, C=0x7fcf89518020, ldc=256000, ith=21, nth=32, task=0, Atype=1, Btype=0, Ctype=0) at sgemm.cpp:827
#8  0x0000555555588c78 in ggml_compute_forward_mul_mat (params=0x7fcf41ffae20, dst=0x55555bf087d0) at ggml.c:10831
#9  0x00005555555a1125 in ggml_compute_forward (params=0x7fcf41ffae20, tensor=0x55555bf087d0) at ggml.c:16254
#10 0x00005555555a75b0 in ggml_graph_compute_thread (data=0x7fffffffb660) at ggml.c:18398
#11 0x00007ffff7aebac3 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
#12 0x00007ffff7b7d850 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81

When I reverted commit 192090b the problem disappeared. Log files attached.

Attachments: main-crash.log, main-working.log

fairydreaming avatar Apr 25 '24 07:04 fairydreaming

Yes, this assert has to be avoided. The Command-R model has a very large output tensor and its number of elements exceeds the range of int. That's why, in order to support it, we switched to int64_t in many places across the codebase: https://github.com/ggerganov/llama.cpp/pull/6491
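
The numbers in the backtrace above make the overflow concrete (frame #7 shows m=256000 and lda=12288):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t m = 256000, lda = 12288;  // from frame #7 of the backtrace
    std::printf("1ll*lda*m = %lld\n", (long long)(1ll*lda*m)); // 3145728000
    std::printf("INT32_MAX = %d\n", INT32_MAX);                // 2147483647
    // 3,145,728,000 > 2,147,483,647, so the assertion fires, and any
    // 32-bit index derived from lda*m would wrap around.
}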

This issue should be prioritized, @jart PTAL

ggerganov avatar Apr 25 '24 12:04 ggerganov

Here are instructions to trigger this assert:

  • clone https://huggingface.co/CohereForAI/c4ai-command-r-plus
# convert to GGUF
python3 convert-hf-to-gguf.py ~/Data/huggingface/c4ai-command-r-plus/ --outfile models/command-r-plus/ggml-model-f16.gguf --outtype f16

# quantize to Q8_0 + F16 token embeddings
make -j
./quantize --token-embedding-type f16 ./models/command-r-plus/ggml-model-f16.gguf ./models/command-r-plus/ggml-model-q8_0.gguf q8_0

# build in DEBUG and run
make clean
LLAMA_DEBUG=1 LLAMA_NO_METAL=1 LLAMA_NO_ACCELERATE=1 make -j
./main -m ./models/command-r-plus/ggml-model-q8_0.gguf

ggerganov avatar Apr 25 '24 18:04 ggerganov

I just tried a naive solution and replaced all ints in sgemm.cpp and sgemm.h with int64_t, and the resulting code works fine without any performance penalty (at least on my EPYC Genoa). There are also no more crashes due to int overflow in pointer calculations when using Command R+.
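
Concretely, using the parameter names visible in the backtrace above, the naive fix amounts to widening every dimension, stride, and index parameter (a sketch, not the exact patch):

#include <cstdint>

// Before (from the assertion message above): every dimension and stride
// is a 32-bit int, so a product like lda*m = 12288*256000 overflows.
bool llamafile_sgemm(int m, int n, int k,
                     const void *A, int lda,
                     const void *B, int ldb,
                     void *C, int ldc,
                     int ith, int nth, int task,
                     int Atype, int Btype, int Ctype);

// After (sketch of the naive fix): widen to int64_t so all index
// arithmetic stays in 64 bits end to end.
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda,
                     const void *B, int64_t ldb,
                     void *C, int64_t ldc,
                     int64_t ith, int64_t nth, int64_t task,
                     int64_t Atype, int64_t Btype, int64_t Ctype);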

By the way, @jart thank you for these changes, they improved the prompt eval time on my system by 65% on llama-3 70B Q8!

fairydreaming avatar Apr 26 '24 11:04 fairydreaming

Thanks for the inbox bump. Making this my top priority now. Expect a PR shortly.

jart avatar Apr 26 '24 13:04 jart