whisper.cpp
port fft to pocketfft
I recently surveyed the available FFT implementations, and pocketfft is extremely competitive even with Accelerate, while being a cross-platform C++11 header-only library. (I've worked with a ton of FFT implementations in the past, e.g. fftw, fftpack, IPP, MKL, mufft, and Accelerate; pocketfft is my current favorite for both performance and simplicity.)
I did some other minor cleanup in the log_mel function while I was in there.
I manually checked and both the FFT outputs and transcription outputs seem to match before/after the change.
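For reference, pocketfft's header-only interface makes the real-to-complex transform essentially a single call. Here's an illustrative sketch of the kind of call involved (not the exact code from this PR; note pocketfft strides are in bytes):

#include <complex>
#include <vector>
#include "pocketfft_hdronly.h"

// Illustrative only: 1-D real-to-complex FFT of one audio frame with pocketfft.
std::vector<std::complex<float>> fft_frame(const std::vector<float> &in) {
    const size_t n = in.size();
    pocketfft::shape_t  shape      = { n };
    pocketfft::stride_t stride_in  = { sizeof(float) };                // byte strides
    pocketfft::stride_t stride_out = { sizeof(std::complex<float>) };
    std::vector<std::complex<float>> out(n/2 + 1);                     // r2c keeps n/2+1 bins
    pocketfft::r2c(shape, stride_in, stride_out, /*axis=*/0,
                   pocketfft::FORWARD, in.data(), out.data(), 1.0f);   // fct = 1: no scaling
    return out;
}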
Here are some timings on an M1 Max with thread counts (1, 2, 4, 8). Seems like a pretty clean 6x speedup on the mel time.
current master
% for t in 1 1 2 2 4 4 8 8; do echo -n "[$t] "; ./main-cpu -m models/ggml-tiny.bin samples/jfk.wav -t $t 2>&1 | grep 'mel time'; done
[1] whisper_print_timings: mel time = 86.76 ms
[1] whisper_print_timings: mel time = 89.72 ms
[2] whisper_print_timings: mel time = 47.61 ms
[2] whisper_print_timings: mel time = 43.96 ms
[4] whisper_print_timings: mel time = 23.25 ms
[4] whisper_print_timings: mel time = 23.16 ms
[8] whisper_print_timings: mel time = 12.14 ms
[8] whisper_print_timings: mel time = 12.35 ms
pocketfft
% for t in 1 1 2 2 4 4 8 8; do echo -n "[$t] "; ./main-pocketfft -m models/ggml-tiny.bin samples/jfk.wav -t $t 2>&1 | grep 'mel time'; done
[1] whisper_print_timings: mel time = 13.52 ms
[1] whisper_print_timings: mel time = 13.48 ms
[2] whisper_print_timings: mel time = 7.04 ms
[2] whisper_print_timings: mel time = 7.11 ms
[4] whisper_print_timings: mel time = 3.79 ms
[4] whisper_print_timings: mel time = 3.89 ms
[8] whisper_print_timings: mel time = 2.10 ms
[8] whisper_print_timings: mel time = 2.07 ms
Here's the full output so you can see the overall timings with transcriptions:
current master
% ./main-cpu -m models/ggml-tiny.bin samples/jfk.wav -t 8
whisper_init_from_file: loading model from 'models/ggml-tiny.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51865
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 384
whisper_model_load: n_audio_head = 6
whisper_model_load: n_audio_layer = 4
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 384
whisper_model_load: n_text_head = 6
whisper_model_load: n_text_layer = 4
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 1
whisper_model_load: mem required = 127.00 MB (+ 3.00 MB per decoder)
whisper_model_load: kv self size = 2.62 MB
whisper_model_load: kv cross size = 8.79 MB
whisper_model_load: adding 1608 extra tokens
whisper_model_load: model ctx = 73.58 MB
whisper_model_load: model size = 73.54 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans ask not what your country can do for you ask what you can do for your country.
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: load time = 61.29 ms
whisper_print_timings: mel time = 12.14 ms
whisper_print_timings: sample time = 10.67 ms / 25 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 93.18 ms / 1 runs ( 93.18 ms per run)
whisper_print_timings: decode time = 45.95 ms / 25 runs ( 1.84 ms per run)
whisper_print_timings: total time = 225.19 ms
pocketfft
% ./main-pocketfft -m models/ggml-tiny.bin samples/jfk.wav -t 8
whisper_init_from_file_no_state: loading model from 'models/ggml-tiny.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51865
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 384
whisper_model_load: n_audio_head = 6
whisper_model_load: n_audio_layer = 4
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 384
whisper_model_load: n_text_head = 6
whisper_model_load: n_text_layer = 4
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 1
whisper_model_load: mem required = 127.00 MB (+ 3.00 MB per decoder)
whisper_model_load: adding 1608 extra tokens
whisper_model_load: model ctx = 73.58 MB
whisper_model_load: model size = 73.54 MB
whisper_init_state: kv self size = 2.62 MB
whisper_init_state: kv cross size = 8.79 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans ask not what your country can do for you ask what you can do for your country.
whisper_print_timings: load time = 59.35 ms
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: mel time = 2.18 ms
whisper_print_timings: sample time = 10.67 ms / 25 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 102.52 ms / 1 runs ( 102.52 ms per run)
whisper_print_timings: decode time = 35.36 ms / 25 runs ( 1.41 ms per run)
whisper_print_timings: total time = 216.80 ms
I also optimized the log_mel matmul a bit; here are the new numbers. The log_mel step now seems about 10x faster than the original, and 1-thread log_mel is now ~40% faster than the original 8-thread log_mel.
% for t in 1 1 2 2 4 4 8 8; do echo -n "[$t] "; ./main-pocketfft -m models/ggml-tiny.bin samples/jfk.wav -t $t 2>&1 | grep 'mel time'; done
[1] whisper_print_timings: mel time = 7.75 ms
[1] whisper_print_timings: mel time = 7.71 ms
[2] whisper_print_timings: mel time = 4.10 ms
[2] whisper_print_timings: mel time = 4.23 ms
[4] whisper_print_timings: mel time = 2.30 ms
[4] whisper_print_timings: mel time = 2.21 ms
[8] whisper_print_timings: mel time = 1.31 ms
[8] whisper_print_timings: mel time = 1.30 ms
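For context on the matmul above: each mel bin is a dot product of the frame's power spectrum against one row of the (n_mel x n_fft) filter matrix. A minimal sketch of that step, with made-up identifiers rather than whisper.cpp's actual ones (the real optimization in this PR may differ, e.g. by skipping the zero tails of each triangular filter):

#include <algorithm>
#include <cmath>

// Illustrative sketch of the log-mel filterbank matmul; names are made up.
void apply_mel_filters(const float *power,   // n_fft power-spectrum bins for one frame
                       const float *filters, // n_mel x n_fft, row-major
                       float *mel,           // n_mel outputs
                       int n_mel, int n_fft) {
    for (int m = 0; m < n_mel; ++m) {
        const float *row = filters + (size_t) m * n_fft;
        double sum = 0.0;
        for (int k = 0; k < n_fft; ++k) {
            sum += (double) row[k] * power[k];
        }
        mel[m] = (float) std::log10(std::max(sum, 1e-10)); // clamp, then log
    }
}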
Adding some benchmarks on my M1 Max here for a fairly small file in the WASM build:
operator(): processing 3151108 samples, 196.9 sec, 8 threads, 1 processors, lang = en, task = transcribe ...
MASTER (chrome):
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: load time = 132.53 ms
whisper_print_timings: mel time = 1488.02 ms
whisper_print_timings: sample time = 493.97 ms / 655 runs ( 0.75 ms per run)
whisper_print_timings: encode time = 21023.62 ms / 9 runs ( 2335.96 ms per run)
whisper_print_timings: decode time = 10138.27 ms / 655 runs ( 15.48 ms per run)
whisper_print_timings: total time = 33402.97 ms
POCKETFFT (chrome):
whisper_print_timings: load time = 118.28 ms
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: mel time = 337.71 ms
whisper_print_timings: sample time = 478.10 ms / 645 runs ( 0.74 ms per run)
whisper_print_timings: encode time = 18374.29 ms / 8 runs ( 2296.79 ms per run)
whisper_print_timings: decode time = 11633.24 ms / 645 runs ( 18.04 ms per run)
whisper_print_timings: total time = 31069.60 ms
MASTER (firefox):
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: load time = 118.54 ms
whisper_print_timings: mel time = 1933.18 ms
whisper_print_timings: sample time = 662.88 ms / 710 runs ( 0.93 ms per run)
whisper_print_timings: encode time = 17380.68 ms / 8 runs ( 2172.59 ms per run)
whisper_print_timings: decode time = 11267.72 ms / 710 runs ( 15.87 ms per run)
whisper_print_timings: total time = 31497.16 ms
POCKETFFT (firefox):
whisper_print_timings: load time = 103.44 ms
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: mel time = 408.76 ms
whisper_print_timings: sample time = 613.67 ms / 651 runs ( 0.94 ms per run)
whisper_print_timings: encode time = 18860.96 ms / 9 runs ( 2095.66 ms per run)
whisper_print_timings: decode time = 8970.86 ms / 651 runs ( 13.78 ms per run)
whisper_print_timings: total time = 29103.28 ms
@lunixbochs It's usually better to put the licenses of subprojects into separate license files such as LICENSE.pocketfft, so the original file stays readable.
I force pushed with a license file split.
Here are numbers from my machine with a much longer test file (1h51m):
main -t 4
whisper_print_timings: fallbacks = 9 p / 24 h
whisper_print_timings: load time = 64.84 ms
whisper_print_timings: mel time = 13256.59 ms
whisper_print_timings: sample time = 13183.77 ms / 28647 runs ( 0.46 ms per run)
whisper_print_timings: encode time = 85530.47 ms / 681 runs ( 125.60 ms per run)
whisper_print_timings: decode time = 63957.57 ms / 28616 runs ( 2.24 ms per run)
whisper_print_timings: total time = 176139.69 ms
main -t 8
whisper_print_timings: fallbacks = 42 p / 115 h
whisper_print_timings: load time = 71.17 ms
whisper_print_timings: mel time = 6870.31 ms
whisper_print_timings: sample time = 16781.76 ms / 32811 runs ( 0.51 ms per run)
whisper_print_timings: encode time = 52132.50 ms / 511 runs ( 102.02 ms per run)
whisper_print_timings: decode time = 67843.08 ms / 32646 runs ( 2.08 ms per run)
whisper_print_timings: total time = 143864.27 ms
this branch -t 4
whisper_print_timings: load time = 63.09 ms
whisper_print_timings: fallbacks = 16 p / 46 h
whisper_print_timings: mel time = 1275.09 ms
whisper_print_timings: sample time = 13702.99 ms / 29504 runs ( 0.46 ms per run)
whisper_print_timings: encode time = 89079.73 ms / 705 runs ( 126.35 ms per run)
whisper_print_timings: decode time = 65472.02 ms / 29444 runs ( 2.22 ms per run)
whisper_print_timings: total time = 169750.16 ms
this branch -t 8
whisper_print_timings: load time = 62.25 ms
whisper_print_timings: fallbacks = 24 p / 54 h
whisper_print_timings: mel time = 686.02 ms
whisper_print_timings: sample time = 19117.31 ms / 39097 runs ( 0.49 ms per run)
whisper_print_timings: encode time = 41683.73 ms / 414 runs ( 100.69 ms per run)
whisper_print_timings: decode time = 73066.01 ms / 39010 runs ( 1.87 ms per run)
whisper_print_timings: total time = 134783.25 ms
@lunixbochs
Thanks for this contribution! Unfortunately, I don't want to merge pocketfft into the project. It's a design decision to keep things minimal, and in this regard the FFT implementation is up to the user to select. We provide just a very basic built-in implementation. In the future, we could extend the C-style API with a mechanism to provide an external callback that computes the FFT, and then the user will be able to provide any implementation they would like.
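To make that idea concrete, such a callback could look roughly like this. This is purely hypothetical; none of these names exist in whisper.cpp today:

// Hypothetical sketch of an external FFT hook for the C-style API.
typedef void (*whisper_fft_fn)(
    const float * input,     // n real input samples
          float * output,    // n complex values, interleaved re/im
    int           n,
    void        * user_data);

// A setter the API could then expose (again, hypothetical):
// void whisper_set_fft_callback(struct whisper_context * ctx,
//                               whisper_fft_fn fn, void * user_data);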
However, your improvements to the matrix multiplication in the log-mel computation are useful, so feel free to create a new PR with just that.
I think this is a mistake, unless you're planning to depend on a BLAS (which IMO is a much more complicated and heavy thing for users to manage, and isn't going to help the wasm user who posted here). From experience, it's very easy to mess up FFT code. Basically every FFT library has subtle differences, and it can take care and experience to make the outputs of any two libraries match (or to even notice that you've made a mistake).
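For example, one cheap guard against those subtle differences (scaling, exponent sign, output bin layout) is to check a backend against a naive O(n^2) DFT on small inputs. A minimal sketch:

#include <cmath>
#include <complex>
#include <vector>

// Reference DFT for sanity-checking an FFT backend on small inputs.
// Forward transform, e^{-2*pi*i*k*t/n} convention, no normalization.
std::vector<std::complex<double>> naive_dft(const std::vector<double> &x) {
    const size_t n = x.size();
    const double pi = std::acos(-1.0);
    std::vector<std::complex<double>> out(n/2 + 1); // spectrum of a real input
    for (size_t k = 0; k <= n/2; ++k) {
        std::complex<double> acc{0.0, 0.0};
        for (size_t t = 0; t < n; ++t) {
            const double phi = -2.0 * pi * (double) (k * t) / (double) n;
            acc += x[t] * std::complex<double>(std::cos(phi), std::sin(phi));
        }
        out[k] = acc;
    }
    return out;
}

If a library's output doesn't match this up to a scale factor and layout permutation, that's exactly the kind of mistake that's easy to miss.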