ggml : rewrite silu and softmax for cpu

jart opened this pull request

This change upstreams llamafile's vectorized expf() functions. This lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make this operation go faster. We can support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. I wrote avx2 and avx512 implementations as well, but they didn't offer enough of an advantage over sse2+fma to be worth the code complexity.

jart, May 09 '24

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 543 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8626.19ms p(95)=21696.44ms fails=, finish reason: stop=474 truncated=69
  • Prompt processing (pp): avg=94.59tk/s p(95)=412.43tk/s
  • Token generation (tg): avg=33.43tk/s p(95)=48.33tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=expf commit=d7359a389c236193edac1c8761e6ac98844654f3

[Charts: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio and llamacpp:requests_processing, each plotted over the run (timestamps 1715376005 to 1715376631) for "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 543 iterations".]

github-actions[bot], May 09 '24

I haven't analysed the changes deeply, but here are some general observations in case they help other reviewers:

  • Commented out #define removed
  • Extracted 5 duplicated lines into ggml_vec_soft_max_f32()
  • Various functions relating to GGML_SILU_FP16 removed
  • ggml_v_expf() added
  • ggml_v_silu() added
  • ggml_vec_silu_f32() now uses preprocessor conditionals to select an implementation based on the SSE2 or __ARM_NEON flags
  • There are other changes, but these are the main ones I noticed.

mofosyne, May 09 '24

On AMD Ryzen 9 5950X and M2 Ultra, SOFT_MAX is about 1.5x faster than master.

Using the following command to benchmark:

make -j tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf

ggerganov, May 09 '24

I'm glad to hear that. Here are the avx2 and avx512 variations if you want to try them out:

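#include <immintrin.h>

// MADD256/NMADD256 are not defined in this snippet; these definitions
// (fused a*b + c and c - a*b) are assumptions. Compile with -mavx2 -mfma.
#define MADD256(a, b, c) _mm256_fmadd_ps(a, b, c)
#define NMADD256(a, b, c) _mm256_fnmadd_ps(a, b, c)

static inline __m256 llamafile_expf_avx2(__m256 x) {
  // n = round(x * log2(e)): adding 1.5 * 2^23 forces rounding to an integer
  const __m256 r = _mm256_set1_ps(0x1.8p23f);
  const __m256 z = MADD256(x, _mm256_set1_ps(0x1.715476p+0f), r);
  const __m256 n = _mm256_sub_ps(z, r);
  // b = x - n * ln(2), with ln(2) split into hi and lo parts
  const __m256 b = NMADD256(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                            NMADD256(n, _mm256_set1_ps(0x1.62e4p-1f), x));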
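  // The low mantissa bits of z hold n, so shifting left by 23 moves n into
  // the exponent field; adding the bits of 1.0f gives k = 2^n
  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
  const __m256 k = _mm256_castsi256_ps(
      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
  // c flags lanes with |n| > 126, where computing k directly could overflow
  // or underflow
  const __m256i c = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(126), _CMP_GT_OQ));
  // j ~= expm1(b), a degree-5 polynomial in b evaluated with u = b*b
  const __m256 u = _mm256_mul_ps(b, b);
  const __m256 j = MADD256(MADD256(MADD256(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                           _mm256_set1_ps(0x1.573e2ep-5f)),
                                   u,
                                   MADD256(_mm256_set1_ps(0x1.555e66p-3f), b,
                                           _mm256_set1_ps(0x1.fffdb6p-2f))),
                           u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
  // Fast path: every lane has |n| <= 126, so exp(x) = k + j*k
  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
    return MADD256(j, k, k);
  // Slow path: build 2^n as s2 * s1 so intermediates stay representable;
  // lanes with |n| > 192 return s1*s1, which is +inf (n > 0) or 0 (n <= 0)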
  const __m256i g = _mm256_and_si256(
      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
      _mm256_set1_epi32(0x82000000u));
  const __m256 s1 =
      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
  const __m256i d = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(192), _CMP_GT_OQ));
  return _mm256_or_ps(
      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
      _mm256_andnot_ps(
          _mm256_castsi256_ps(d),
          _mm256_or_ps(
              _mm256_and_ps(_mm256_castsi256_ps(c),
                            _mm256_mul_ps(MADD256(s2, j, s2), s1)),
              _mm256_andnot_ps(_mm256_castsi256_ps(c), MADD256(k, j, k)))));
}

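// MADD512/NMADD512 are not defined in this snippet; these definitions
// (fused a*b + c and c - a*b) are assumptions. Compile with -mavx512f
// -mavx512dq (_mm512_movm_epi32 requires AVX512DQ).
#define MADD512(a, b, c) _mm512_fmadd_ps(a, b, c)
#define NMADD512(a, b, c) _mm512_fnmadd_ps(a, b, c)

static inline __m512 llamafile_expf_avx512(__m512 x) {
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
  const __m512 z = MADD512(x, _mm512_set1_ps(0x1.715476p+0f), r);
  const __m512 n = _mm512_sub_ps(z, r);
  const __m512 b = NMADD512(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                            NMADD512(n, _mm512_set1_ps(0x1.62e4p-1f), x));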
  const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
  const __m512 k = _mm512_castsi512_ps(
      _mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
  const __mmask16 c =
      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
  const __m512 u = _mm512_mul_ps(b, b);
  const __m512 j = MADD512(MADD512(MADD512(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                           _mm512_set1_ps(0x1.573e2ep-5f)),
                                   u,
                                   MADD512(_mm512_set1_ps(0x1.555e66p-3f), b,
                                           _mm512_set1_ps(0x1.fffdb6p-2f))),
                           u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
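  // Fast path: kortestz returns 1 when mask c is all zero, i.e. no lane
  // has |n| > 126, so exp(x) = k + j*k
  if (_mm512_kortestz(c, c))
    return MADD512(j, k, k);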
  const __m512i g = _mm512_and_si512(
      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
      _mm512_set1_epi32(0x82000000u));
  const __m512 s1 =
      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
  const __mmask16 d =
      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
  return _mm512_mask_blend_ps(
      d,
      _mm512_mask_blend_ps(c, MADD512(k, j, k),
                           _mm512_mul_ps(MADD512(s2, j, s2), s1)),
      _mm512_mul_ps(s1, s1));
}

Here are the numbers I got with the script I used to develop these functions:

   2.98601 ns 2000x run_expf()
   1.35154 ns 2000x run_llamafile_expf_sse2()
   1.16659 ns 2000x run_llamafile_expf_avx2()
   1.18844 ns 2000x run_llamafile_expf_avx512()

//       input          exp    llamafile   bad
//       =====          ===    =========   ===
//           0            1            1     0
//          -0            1            1     0
//         nan          nan          nan     0
//        -nan         -nan         -nan     0
//         inf          inf          inf     0
//        -inf            0            0     0
//          87  6.07603e+37  6.07603e+37     1
//          88  1.65164e+38  1.65164e+38     0
//     88.7229          inf          inf     0
//          89          inf          inf     0
//         -87  1.64581e-38  1.64581e-38     1
//         -90  8.19401e-40  8.19401e-40     0
//         -95  5.52112e-42  5.52112e-42     0
//        -100  3.78351e-44  3.78351e-44     0
//        -104            0            0     0
//    0.660001      1.93479      1.93479     1
//   -0.324231     0.723083     0.723083     0
//   0.0205384      1.02075      1.02075     0
//   -0.224604     0.798833     0.798833     1
//   -0.339606     0.712051      0.71205     1
//    0.211472       1.2355       1.2355     0
//    0.238942       1.2699       1.2699     0
//    -0.78286     0.457097     0.457097     0
4294967296 numbers tested successfully

jart, May 09 '24

@ggerganov Running your command, I'm noticing the advantage increases from 1.5x to 1.9x if we include AVX2. On znver4, if we also include avx512, it goes up to 2.1x. I'd expect that to go higher in the future, since znver4 only really implements the AVX512 ISA and uses 2 cycles for each vector operation. So I've gone ahead and included the code for you.

jart, May 09 '24

With AVX512, you may want to use vscalefps. It computes zmm0 = zmm1 * 2^floor(zmm2), where all operands are floats.

It handles overflow and underflow properly, which lets you remove the checks and blends.

I have an implementation in Julia. For example, here is a loop from it with 4x unrolling and interleaving:

L304:
	vmovups	zmm15, zmmword ptr [r11 + 4*rax]
	vmovups	zmm14, zmmword ptr [r11 + 4*rax + 64]
	vmovups	zmm13, zmmword ptr [r11 + 4*rax + 128]
	vmovups	zmm12, zmmword ptr [r11 + 4*rax + 192]
	vmovaps	zmm16, zmm1
	vfmadd213ps	zmm16, zmm15, zmm0      # zmm16 = (zmm15 * zmm16) + zmm0
	vmovaps	zmm17, zmm1
	vfmadd213ps	zmm17, zmm14, zmm0      # zmm17 = (zmm14 * zmm17) + zmm0
	vmovaps	zmm18, zmm1
	vfmadd213ps	zmm18, zmm13, zmm0      # zmm18 = (zmm13 * zmm18) + zmm0
	vmovaps	zmm19, zmm1
	vfmadd213ps	zmm19, zmm12, zmm0      # zmm19 = (zmm12 * zmm19) + zmm0
	vaddps	zmm16, zmm16, zmm2
	vaddps	zmm17, zmm17, zmm2
	vaddps	zmm18, zmm18, zmm2
	vaddps	zmm19, zmm19, zmm2
	vfmadd231ps	zmm15, zmm16, zmm3      # zmm15 = (zmm16 * zmm3) + zmm15
	vfmadd231ps	zmm14, zmm17, zmm3      # zmm14 = (zmm17 * zmm3) + zmm14
	vfmadd231ps	zmm13, zmm18, zmm3      # zmm13 = (zmm18 * zmm3) + zmm13
	vfmadd231ps	zmm12, zmm19, zmm3      # zmm12 = (zmm19 * zmm3) + zmm12
	vfmadd231ps	zmm15, zmm16, zmm4      # zmm15 = (zmm16 * zmm4) + zmm15
	vfmadd231ps	zmm14, zmm17, zmm4      # zmm14 = (zmm17 * zmm4) + zmm14
	vfmadd231ps	zmm13, zmm18, zmm4      # zmm13 = (zmm18 * zmm4) + zmm13
	vfmadd231ps	zmm12, zmm19, zmm4      # zmm12 = (zmm19 * zmm4) + zmm12
	vmovaps	zmm20, zmm6
	vfmadd213ps	zmm20, zmm15, zmm5      # zmm20 = (zmm15 * zmm20) + zmm5
	vmovaps	zmm21, zmm6
	vfmadd213ps	zmm21, zmm14, zmm5      # zmm21 = (zmm14 * zmm21) + zmm5
	vmovaps	zmm22, zmm6
	vfmadd213ps	zmm22, zmm13, zmm5      # zmm22 = (zmm13 * zmm22) + zmm5
	vmovaps	zmm23, zmm6
	vfmadd213ps	zmm23, zmm12, zmm5      # zmm23 = (zmm12 * zmm23) + zmm5
	vfmadd213ps	zmm20, zmm15, zmm7      # zmm20 = (zmm15 * zmm20) + zmm7
	vfmadd213ps	zmm21, zmm14, zmm7      # zmm21 = (zmm14 * zmm21) + zmm7
	vfmadd213ps	zmm22, zmm13, zmm7      # zmm22 = (zmm13 * zmm22) + zmm7
	vfmadd213ps	zmm23, zmm12, zmm7      # zmm23 = (zmm12 * zmm23) + zmm7
	vfmadd213ps	zmm20, zmm15, zmm8      # zmm20 = (zmm15 * zmm20) + zmm8
	vfmadd213ps	zmm21, zmm14, zmm8      # zmm21 = (zmm14 * zmm21) + zmm8
	vfmadd213ps	zmm22, zmm13, zmm8      # zmm22 = (zmm13 * zmm22) + zmm8
	vfmadd213ps	zmm23, zmm12, zmm8      # zmm23 = (zmm12 * zmm23) + zmm8
	vfmadd213ps	zmm20, zmm15, zmm9      # zmm20 = (zmm15 * zmm20) + zmm9
	vfmadd213ps	zmm21, zmm14, zmm9      # zmm21 = (zmm14 * zmm21) + zmm9
	vfmadd213ps	zmm22, zmm13, zmm9      # zmm22 = (zmm13 * zmm22) + zmm9
	vfmadd213ps	zmm23, zmm12, zmm9      # zmm23 = (zmm12 * zmm23) + zmm9
	vfmadd213ps	zmm20, zmm15, zmm10     # zmm20 = (zmm15 * zmm20) + zmm10
	vfmadd213ps	zmm21, zmm14, zmm10     # zmm21 = (zmm14 * zmm21) + zmm10
	vfmadd213ps	zmm22, zmm13, zmm10     # zmm22 = (zmm13 * zmm22) + zmm10
	vfmadd213ps	zmm23, zmm12, zmm10     # zmm23 = (zmm12 * zmm23) + zmm10
	vfmadd213ps	zmm20, zmm15, zmm11     # zmm20 = (zmm15 * zmm20) + zmm11
	vfmadd213ps	zmm21, zmm14, zmm11     # zmm21 = (zmm14 * zmm21) + zmm11
	vfmadd213ps	zmm22, zmm13, zmm11     # zmm22 = (zmm13 * zmm22) + zmm11
	vfmadd213ps	zmm23, zmm12, zmm11     # zmm23 = (zmm12 * zmm23) + zmm11
	vfmadd213ps	zmm20, zmm15, zmm11     # zmm20 = (zmm15 * zmm20) + zmm11
	vfmadd213ps	zmm21, zmm14, zmm11     # zmm21 = (zmm14 * zmm21) + zmm11
	vfmadd213ps	zmm22, zmm13, zmm11     # zmm22 = (zmm13 * zmm22) + zmm11
	vfmadd213ps	zmm23, zmm12, zmm11     # zmm23 = (zmm12 * zmm23) + zmm11
	vscalefps	zmm12, zmm20, zmm16, {rn-sae}
	vscalefps	zmm13, zmm21, zmm17, {rn-sae}
	vscalefps	zmm14, zmm22, zmm18, {rn-sae}
	vscalefps	zmm15, zmm23, zmm19, {rn-sae}
	vmovups	zmmword ptr [r14 + 4*rax], zmm12
	vmovups	zmmword ptr [r14 + 4*rax + 64], zmm13
	vmovups	zmmword ptr [r14 + 4*rax + 128], zmm14
	vmovups	zmmword ptr [r14 + 4*rax + 192], zmm15
	add	rax, 64
	cmp	rax, r10
	jl	L304

These gave me a significant performance improvement. If my test is correct, I got a maximum error <1 ULP at x=47.483456f.

What hardware are you on? I'm using skylake-avx512/cascadelake with 2x fma units. Zen4 or something like icelake-client/tigerlake likely won't benefit as much.

Note that it doesn't use a lookup table. My Float64/double implementation uses a 16-element lookup table via vpermi2pd. If we wanted, we could use a 32-element lookup table of floats via the same approach. vpermi2pd is much faster than gather, the cost of course being that our table has to fit into two registers.

chriselrod, May 16 '24

Since this PR was merged, the server has been producing nondeterministic results when using more than one slot. Minimal example for reproduction:

make clean && make server
./server -m models/opt/llama_2-7b-q4_0.gguf --parallel 2 --threads 1

In another shell:

curl --request POST --url http://localhost:8080/completion --header "Content-Type: application/json" --data '{"prompt": "", "n_predict":10, "n_probs": 2, "temperature": -1}' | python3 -m json.tool

The token probabilities for the last token cycle between two values with every curl call. When using 4 slots the token probabilities cycle between 4 possible values.

JohannesGaessler, May 19 '24

@chriselrod Could you help me modify my avx512 intrinsics to use _mm512_scalef_ps (vscalefps) like your code? I'm currently talking to ARM Limited about getting these functions into Glibc, since our code goes faster. https://github.com/ARM-software/optimized-routines/pull/69

jart, May 22 '24

@jart Sure. If it helps, I just wrote a C++ implementation you can look at here: https://github.com/chriselrod/ExpAVX512 The source is in include/ExpAVX512.hpp.

The README contains benchmarks. I didn't benchmark against any other implementations, though (but I think the assembly looks quite good when using clang++-18; I did not try other compilers).

I haven't done much analysis other than a glance to see that unrolling w/ interleaving and especially larger SIMD vectors boost performance on smaller array sizes that fit in cache on my desktop.

The basic algorithm for base^x (where ^ means pow, not xor) is based on:

x = r + n_float/log2(base), where n is integer-valued
base^x
= 2^(log2(base^x))
= 2^(x*log2(base))
= 2^(r*log2(base) + n_float)
= 2^(r*log2(base)) * 2^n_float
= base^r * 2^n_float

So, with that math in mind, the basic algorithm is:

  1. Decompose x into r and n_float/log2(base). This is trivial when base=2 (note that exp2 performs much better than exp in my benchmarks) thanks to vreduceps, an instruction which computes r directly; as log2(2) = 1.0, we then have n_float = x-r. For other bases, we multiply by log2(base) while adding and subtracting a huge constant to force rounding. Then, for r, we compute x - n_float*(1/log2(base)) in two steps (note that 1/log2(base) = log_base(2)), using hi and lo parts of log_base(2) calculated in higher precision to get r with sufficient accuracy.
  2. Kernel to calculate base_to_rth_power = base^r
  3. vscalefps to get base_to_rth_power * 2^n_float.

chriselrod, May 26 '24

I think this PR breaks something: it causes stable-diffusion.cpp to generate completely black images for any input when merged into it. Reverting commit 934266c (and only this commit) solves the issue, and since the API is exactly the same and no files besides ggml.c are changed, I suspect there is something incorrect with the implementation.

Giving a heads up to @leejet since this is not actually merged into stable-diffusion.cpp yet.

Only affects the CPU backend.

LostRuins, May 28 '24

I'm not skilled enough to fix it, but I just reverted to the precomputed silu table for ggml_vec_silu_f32 and it worked https://github.com/LostRuins/koboldcpp/commit/b5401a2901c12a3160ff1d034eef22277764b071

LostRuins, May 28 '24

@LostRuins Which instruction set do you observe to fail (ARM, AVX, ..)?

ggerganov, May 29 '24

Hmm I think it should be the AVX2/FMA implementation, although I am not sure if the same issue affects other instruction sets as well.

I just tested again without -mavx2 (which presumably should be using the SSE2 impl) and it also returns a black square.

LostRuins, May 30 '24

I just imported stable-diffusion.cpp into the llamafile codebase, which uses these expf() functions, and things work fine. I'm not seeing any black squares. I even enabled trapping math to be sure. There's not a single nan or overflow in the entire program. make -j32 o//stable-diffusion.cpp/main && o//stable-diffusion.cpp/main -m /weights/stable-diffusion-2-1.F16.gguf -p 'tortoiseshell cat named varu' -t 20 --trap. I'm not convinced there's anything wrong with this PR because I tested every single 32-bit floating point number.

jart, May 30 '24

Perhaps you are right, maybe I am doing something else wrongly. Anyway I think we can leave it for now unless someone else encounters the same issue.

LostRuins, May 31 '24

I just updated the ggml of stable-diffusion.cpp to the latest commit. It works great, the image generation speed on the CPU has improved, and I haven't encountered any issues.

leejet, Jun 01 '24

Perhaps I made a mistake when merging it in. I will look through it again. Thanks.

LostRuins, Jun 02 '24

@ggerganov @jart @leejet Not entirely sure why, but I found the cause!

The reason is: I have been building with -Ofast

Until this PR, it had been working fine for me; everything else (except CUDA) works fine with -Ofast.

Switching back to -O3 makes everything work again!

I also dug a bit further into ggml_vec_silu_f32: the -Ofast incompatibility is with the

#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }

versions, but if I force the non-intrinsic version that uses only the plain f32 CPU path

  for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }

then everything works well even with -Ofast

Building directly with sd.cpp's CMake setup picks -O3 by default, so you wouldn't have seen any issues.

This is pretty interesting. Do you think you could get it working with -Ofast? I do get significant speedups with -Ofast, so it would be nice to be able to build with that flag. Otherwise, do you know if it's possible to instruct the compiler not to optimize that specific function? Compile flags are set per file, so is the only way to split the function out into a different file?

Or maybe there's a more elegant solution?

LostRuins, Jun 02 '24

If you want to dig into this more, look at the GCC compiler flags enabled by -Ofast and try to isolate which one is causing issues.

JohannesGaessler, Jun 02 '24

I've narrowed it down to one of the aspects of -ffast-math, haven't figured out which one yet.

LostRuins, Jun 02 '24

Okay! I think I've nailed down the flag that causes this PR to break: it is -ffinite-math-only.

I guess I could get rid of that flag, but I am curious whether it's solvable, considering all other ops are fine.

LostRuins, Jun 02 '24

What could be happening is that the exponential function in SiLU, instead of flushing small values to 0, returns NaN or some other garbage. I've essentially had this same issue in CUDA for the exponentiation in FlashAttention. I fixed it by explicitly zeroing the exponentials of all values < -20.0f, since those are going to be below about 2.1e-9 and are therefore negligible anyway.

The relevant FP32 code looks something like this:

const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
    KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];

JohannesGaessler, Jun 02 '24

The INFINITY constant alone is used 83 times in the llama.cpp codebase, so compiling with -ffinite-math-only might not be a bright idea. If you want us to stop using infinity and start employing ugly workarounds instead, it'd help if you could communicate exactly what we stand to gain. You said faster. How much faster? On what specific hardware? Could you use the Linux Perf tool to isolate which specific code is impacted? E.g. perf record --call-graph dwarf ./main ... and perf report.

jart, Jun 02 '24

> If you want us to stop using infinity and start employing ugly workarounds instead, it'd help if you could communicate exactly what we stand to gain.

In my particular case (CUDA FP16 FlashAttention), I'm not sure the issue is fixable by a compiler flag; the NVCC documentation does not seem to mention any equivalent flag. I have not actually encountered the issue at FP32 precision, but I still implemented the same flush-to-zero behavior for consistency.

I generally agree that just not using -ffinite-math-only is a better solution than a workaround and I did not mean to say that this is how the ggml code should be modified. My intent was to provide further context for what I think is the exact issue and to provide an example workaround that could be used for testing.

JohannesGaessler, Jun 02 '24

Yeah, now that I know what's wrong I'd probably just not use that flag. I'm actually more surprised that there have been no (apparent) issues with finite math before this one, so I guess 'what we stand to gain' is simply retaining existing compatibility with -Ofast when CUDA is not used.

In any case I think we can consider the issue resolved then.

LostRuins, Jun 03 '24

> The INFINITY constant alone is used 83 times in the llama.cpp codebase, so compiling with -ffinite-math-only might not be a bright idea. If you want us to stop using infinity and start employing ugly workarounds instead, it'd help if you could communicate exactly what we stand to gain. You said faster. How much faster? On what specific hardware?

Your accuracy tests run almost 4x faster in wall time, and use about a third less CPU time (123 s vs. 80 s of user time), under -Ofast on my 10980XE. Not all threads benefited equally; a normal run is dominated by a few slow threads.

```
> time ./exp_accuracy

   3.13724 ns 100000x run_expf()
   1.16232 ns 100000x run_llamafile_expf_avx512()
   1.15372 ns 100000x run_chris_expf_avx512()

//       input          exp            v   bad
//       =====          ===    =========   ===
//           0            1            1     0
//          -0            1            1     0
//         nan          nan          nan     0
//        -nan         -nan         -nan     0
//         inf          inf          inf     0
//        -inf            0            0     0
//          87  6.07603e+37  6.07603e+37     1
//          88  1.65164e+38  1.65164e+38     0
//     88.7229          inf          inf     0
//          89          inf          inf     0
//         -87  1.64581e-38  1.64581e-38     1
//         -90  8.19401e-40  8.19401e-40     0
//         -95  5.52112e-42  5.52112e-42     0
//        -100  3.78351e-44  3.78351e-44     0
//        -104            0            0     0
//    -2.45731    0.0856653    0.0856653     0
//    0.301039      1.35126      1.35126     0
//      4.7475      115.296      115.296     1
//    -3.74837     0.023556     0.023556     1
//     9.45433      12763.3      12763.3     0
//   0.0163593      1.01649      1.01649     0
//     1.72593      5.61775      5.61775     1
//    -13.9066  9.12951e-07  9.12951e-07     1
4294967296 numbers tested successfully

________________________________________________________
Executed in   12.25 secs    fish           external
   usr time  123.22 secs  325.00 micros  123.22 secs
   sys time    0.31 secs  161.00 micros    0.31 secs
```

```
chriselrod@fedora ~/D/p/c/misc> time ./exp_accuracy_fast

   2.99475 ns 100000x run_expf()
   1.15154 ns 100000x run_llamafile_expf_avx512()
   1.14108 ns 100000x run_chris_expf_avx512()

//       input          exp            v   bad
//       =====          ===    =========   ===
//           0            1            1     0
//          -0            1            1     0
//         nan          nan          nan     0
//        -nan         -nan         -nan     0
//         inf          inf          inf     0
//        -inf            0            0     0
//          87  6.07603e+37  6.07603e+37     1
//          88  1.65164e+38  1.65164e+38     0
//     88.7229          inf          inf     0
//          89          inf          inf     0
//         -87  1.64581e-38  1.64581e-38     1
//         -90            0            0     0
//         -95            0            0     0
//        -100            0            0     0
//        -104            0            0     0
//    -2.45731    0.0856653    0.0856653     0
//    0.301039      1.35126      1.35126     0
//      4.7475      115.296      115.296     1
//    -3.74837     0.023556     0.023556     1
//     9.45433      12763.3      12763.3     0
//   0.0163593      1.01649      1.01649     0
//     1.72593      5.61775      5.61775     1
//    -13.9066  9.12951e-07  9.12951e-07     1
4294967296 numbers tested successfully

________________________________________________________
Executed in    3.19 secs    fish           external
   usr time   80.03 secs  322.00 micros   80.03 secs
   sys time    0.24 secs  155.00 micros    0.24 secs
```

Note how expf(x) for x <= -90 is flushed to 0 in the fast version. I suspected that flushing denormals to 0 was the cause of the performance gain.

Denormals are extremely slow on Intel hardware (I believe the penalty is much lower on AMD).
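To make the denormal range concrete, here is a small C++ sketch that classifies std::exp results for a float input, assuming a default (non-fast-math) build. ExpClass and classify_exp are hypothetical helpers. Inputs around -87 still give normal results, -90 through -100 land below FLT_MIN (about 1.18e-38) and are therefore subnormal, and -104 rounds to zero, matching the first table above.

```cpp
#include <cassert>
#include <cmath>

// Category of a single-precision exp() result, as relevant to FTZ/DAZ.
enum class ExpClass { Zero, Subnormal, Normal, Other };

// Classifies std::exp(x) for a float input. Subnormal results (nonzero but
// below FLT_MIN) are the ones FTZ turns into 0, and the ones that can hit
// slow microcode paths on some x86 cores.
ExpClass classify_exp(float x) {
    switch (std::fpclassify(std::exp(x))) {
        case FP_ZERO:      return ExpClass::Zero;
        case FP_SUBNORMAL: return ExpClass::Subnormal;
        case FP_NORMAL:    return ExpClass::Normal;
        default:           return ExpClass::Other;  // infinity or NaN
    }
}
```

In this build, classify_exp(-87.0f) is Normal, classify_exp(-90.0f) is Subnormal, classify_exp(-104.0f) is Zero, and classify_exp(89.0f) is Other because the result overflows to infinity.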

To confirm, I built clang-19 from source, as it added the -mdaz-ftz flag to selectively enable flushing denormals to 0 without enabling any of the other -ffast-math optimizations:

```
> ~/Documents/libraries/llvm-project/build/bin/clang++ -O3 -mdaz-ftz -std=c++20 -fopenmp -march=native -o exp_accuracy_denormal exp_accuracy.cpp

> time ./exp_accuracy_denormal

   2.98692 ns 100000x run_expf()
   1.15956 ns 100000x run_llamafile_expf_avx512()
   1.15179 ns 100000x run_chris_expf_avx512()

//       input          exp            v   bad
//       =====          ===    =========   ===
//           0            1            1     0
//          -0            1            1     0
//         nan          nan          nan     0
//        -nan         -nan         -nan     0
//         inf          inf          inf     0
//        -inf            0            0     0
//          87  6.07603e+37  6.07603e+37     1
//          88  1.65164e+38  1.65164e+38     0
//     88.7229          inf          inf     0
//          89          inf          inf     0
//         -87  1.64581e-38  1.64581e-38     1
//         -90            0            0     0
//         -95            0            0     0
//        -100            0            0     0
//        -104            0            0     0
//    -2.45731    0.0856653    0.0856653     0
//    0.301039      1.35126      1.35126     0
//      4.7475      115.296      115.296     1
//    -3.74837     0.023556     0.023556     1
//     9.45433      12763.3      12763.3     0
//   0.0163593      1.01649      1.01649     0
//     1.72593      5.61775      5.61775     1
//    -13.9066  9.12951e-07  9.12951e-07     1
4294967296 numbers tested successfully

________________________________________________________
Executed in    3.05 secs    fish           external
   usr time   74.09 secs  245.00 micros   74.09 secs
   sys time    0.27 secs  126.00 micros    0.27 secs
```

I see that essentially the entirety of the performance improvement in that benchmark comes from flushing denormals to 0.
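For reference, these are the wall-clock and user-time ratios implied by the three time outputs above, each relative to the default -O3 build:

```math
\begin{aligned}
\text{-Ofast:}\quad & \frac{12.25\,\text{s}}{3.19\,\text{s}} \approx 3.84\times\ \text{(wall)}, \qquad \frac{123.22\,\text{s}}{80.03\,\text{s}} \approx 1.54\times\ \text{(user)} \\
\text{-mdaz-ftz:}\quad & \frac{12.25\,\text{s}}{3.05\,\text{s}} \approx 4.02\times\ \text{(wall)}, \qquad \frac{123.22\,\text{s}}{74.09\,\text{s}} \approx 1.66\times\ \text{(user)}
\end{aligned}
```

The DAZ/FTZ-only build matches or slightly exceeds the -Ofast build on both measures, and the wall-time gain is much larger than the CPU-time gain, consistent with a run being dominated by a few slow threads.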

@LostRuins it would be nice if you could pin down what is causing the performance improvements you see. I suspect real applications aren't going to encounter many denormals; the accuracy test tries every float value as input. Many of those inputs trigger this slow denormal path, but that is far from a typical distribution of inputs.

chriselrod, Jun 03 '24 04:06

I concur. I tested every single one of the -ffast-math flags and couldn't find any improvement in my accuracy script, except for -funsafe-math-optimizations, which caused a 20% reduction in OS-measured CPU time. So, per your suggestion, I got rid of that flag and used the following code instead, at the beginning of main():

```c
#ifdef __x86_64__
    //
    // Enable hardware optimizations in violation of the IEEE standard.
    //
    // - 0x0040 enables "DAZ: Denormals Are Zeros" in MXCSR. This causes the
    //   processor to turn denormal inputs into zero, before computing them.
    //   See Intel Manual Vol. 1 §10.2.3.4
    //
    // - 0x8000 enables "FTZ: Flush To Zero" in MXCSR. This means a floating
    //   point operation that results in underflow will be set to zero, with
    //   the same sign, rather than producing a denormalized output. It will
    //   happen only if underflow trapping hasn't been enabled. See the Intel
    //   Manual Vol. 1 §10.2.3.3.
    //
    unsigned mxcsr;
    asm("stmxcsr\t%0" : "=m"(mxcsr));
    mxcsr |= 0x8040;
    asm("ldmxcsr\t%0" : /* no inputs */ : "m"(mxcsr));
#endif
```

Then I saw the same 20% improvement. So I think what was happening is that the -funsafe-math-optimizations flag asks the libc runtime to run the above assembly code in _start().
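The same MXCSR change can also be written with the SSE intrinsics _mm_getcsr and _mm_setcsr from <xmmintrin.h> instead of inline asm. The sketch below assumes an x86-64 target and GCC or Clang (for [[gnu::noinline]]). It scopes the change with a hypothetical RAII guard, ScopedFlushDenormals, rather than setting it for the whole process, and uses a hypothetical helper, tiny_product, to show the effect on a product that would otherwise be subnormal.

```cpp
#if !defined(__x86_64__)
#error "This MXCSR sketch assumes an x86-64 target"
#endif

#include <cassert>
#include <cfloat>
#include <cmath>
#include <xmmintrin.h>  // _mm_getcsr, _mm_setcsr

constexpr unsigned kMxcsrDAZ = 0x0040;  // Denormals-Are-Zeros (inputs)
constexpr unsigned kMxcsrFTZ = 0x8000;  // Flush-To-Zero (outputs)

// Sets DAZ and FTZ for the calling thread and restores the previous MXCSR
// value when the guard goes out of scope. MXCSR is per-thread state, so
// this does not change the mode of threads that are already running.
class ScopedFlushDenormals {
public:
    ScopedFlushDenormals() : saved_(_mm_getcsr()) {
        _mm_setcsr(saved_ | kMxcsrDAZ | kMxcsrFTZ);
    }
    ~ScopedFlushDenormals() { _mm_setcsr(saved_); }
    ScopedFlushDenormals(const ScopedFlushDenormals&) = delete;
    ScopedFlushDenormals& operator=(const ScopedFlushDenormals&) = delete;

private:
    unsigned saved_;
};

// FLT_MIN * 0.3f underflows to a subnormal value. The volatile operands and
// noinline keep the multiply at run time, under whatever MXCSR mode is active.
[[gnu::noinline]] float tiny_product() {
    volatile float a = FLT_MIN;
    volatile float b = 0.3f;
    return a * b;
}
```

Outside the guard, tiny_product() returns a subnormal; inside it, FTZ makes the same multiply return 0. The scoped form limits the IEEE deviation to the region that opts in, whereas the asm in main() affects every later computation on that thread.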

That doesn't mean this will help with inference. I've yet to find an LLM that underflows often enough for FTZ to matter. For example, if I ask Mistral 7b v0.3 to process a 215-token prompt, the process underflows 4500 times, with no difference in clock() time or latency.
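The thread does not say how the 4500 figure was obtained. One way to count underflows in a single softmax row without hardware counters is to check each exp(x - max) against FLT_MIN: a result below it is either subnormal or has already been rounded to zero. count_softmax_underflows is a hypothetical helper, not part of ggml or llamafile, and it counts the same way whether or not FTZ is enabled.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <vector>

// Counts how many terms of a softmax row underflow when exponentiated,
// i.e. how many exp(x - max) results are below FLT_MIN (subnormal or 0).
// Masked entries (-INFINITY) are skipped: exp(-inf) == 0 is exact, not an
// underflow. Note that std::isinf itself is unreliable under -ffinite-math-only.
std::size_t count_softmax_underflows(const std::vector<float>& row) {
    assert(!row.empty());
    const float max_val = *std::max_element(row.begin(), row.end());
    std::size_t count = 0;
    for (float x : row) {
        if (std::isinf(x)) continue;          // masked position
        const float e = std::exp(x - max_val);
        if (e < FLT_MIN) ++count;             // subnormal, or flushed to 0
    }
    return count;
}
```

For a row {0, -1, -90, -104, -INFINITY}, this reports 2: exp(-90) is subnormal, exp(-104) rounds to 0, and the masked entry is ignored.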

jart, Jun 03 '24 08:06

Can we enforce a compile error if -ffinite-math-only is used during compilation in order to prevent such issues in the future?

ggerganov, Jun 03 '24 11:06

@ggerganov I think there is a flag -fno-finite-math-only which should enforce this, even if -Ofast or -ffast-math is used. It allows preserving the rest of the optimizations and only turning off this one. Perhaps that could be used?

LostRuins, Jun 03 '24 14:06

I was thinking of a change in the source code instead: the build system is not standardised, so nothing prevents 3rd party projects from building with -ffinite-math-only. Maybe we can check __FINITE_MATH_ONLY__ and, if it is equal to 1, either fall back to the scalar implementation or assert.
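A minimal sketch of that source-level check, assuming GCC or Clang, both of which predefine __FINITE_MATH_ONLY__ (to 1 when -ffinite-math-only is in effect, including when it is implied by -ffast-math or -Ofast). ALLOW_FINITE_MATH_FALLBACK, FINITE_MATH_BUILD and kUseScalarFallback are hypothetical names: by default the build fails with a clear message, and defining the opt-in macro would instead select a scalar path.

```cpp
#include <cassert>
#include <cmath>

// Compile-time guard: refuse -ffinite-math-only builds unless the
// (hypothetical) opt-in macro asks for a scalar fallback instead.
#if defined(__FINITE_MATH_ONLY__) && __FINITE_MATH_ONLY__
#  if defined(ALLOW_FINITE_MATH_FALLBACK)  // hypothetical opt-in
#    define FINITE_MATH_BUILD 1
#  else
#    error "-ffinite-math-only is not supported: this code relies on INFINITY (add -fno-finite-math-only)"
#  endif
#else
#  define FINITE_MATH_BUILD 0
#endif

// Hypothetical dispatch flag: a fallback build would route softmax/silu to
// the scalar implementation instead of the vectorized expf() paths.
constexpr bool kUseScalarFallback = FINITE_MATH_BUILD != 0;
```

In a default build, kUseScalarFallback is false and INFINITY behaves as IEEE infinity. Checking in the source rather than in CMake also covers third-party build systems, which is the point raised above.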

ggerganov, Jun 03 '24 16:06