llama.cpp
Continuous layouts for quantization q4_0c
Adds a q4_0c type that corresponds to the q4_0 format, but with a different memory layout.
This is a draft. It is currently only accelerated for AVX-512; I will add a PoC of Neon acceleration, but wanted to put this out there since there is some experimentation with quantization formats going on now.
The layout consists of all the quantized values first, in blocks of 128 nibbles, followed by all the scales. Within a block, the nibbles are laid out consecutively in the lower nibbles and then consecutively in the higher nibbles, i.e. the first 64 values go in the low nibbles of the block's 64 bytes and the next 64 in the high nibbles. For dot products we use a q8_0c format, with all the qs bytes followed by all the scales.
The big win is for architectures with larger registers like AVX-512, which can now get two contiguous blocks of qs by doing roughly:
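To make the layout concrete, here is a minimal sketch of the index arithmetic for one q4_0c row, as I read the description above. The helper names (q4c_get_quant, q4c_set_quant, q4c_get, ...) are hypothetical, and the 32-bit float scales and q4_0-style offset of 8 are assumptions, not details taken from the PR code.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Sketch only: 32 values per scale, 128 values (64 bytes) per nibble group.
 * Assumed, not taken from the PR: 32-bit float scales and the q4_0 offset of 8. */
enum { Q4C_BLOCK = 32, Q4C_GROUP = 128, Q4C_GROUP_BYTES = Q4C_GROUP / 2 };

/* All n/2 quant bytes come first, so the scales start at byte n/2. */
static size_t q4c_scales_offset(size_t n) { return n / 2; }

/* Element k lives in group g at byte r % 64; the first 64 values of a group
 * use the low nibbles, the next 64 use the high nibbles. */
static unsigned q4c_get_quant(const uint8_t *row, size_t k) {
    const size_t g = k / Q4C_GROUP, r = k % Q4C_GROUP;
    const uint8_t b = row[g * Q4C_GROUP_BYTES + r % Q4C_GROUP_BYTES];
    return r < Q4C_GROUP_BYTES ? (b & 0x0fu) : (unsigned)(b >> 4);
}

/* Store a 4-bit value q (0..15) for element k without touching the other nibble. */
static void q4c_set_quant(uint8_t *row, size_t k, unsigned q) {
    const size_t g = k / Q4C_GROUP, r = k % Q4C_GROUP;
    uint8_t *b = &row[g * Q4C_GROUP_BYTES + r % Q4C_GROUP_BYTES];
    if (r < Q4C_GROUP_BYTES) *b = (uint8_t)((*b & 0xf0u) | (q & 0x0fu));
    else                     *b = (uint8_t)((*b & 0x0fu) | ((q & 0x0fu) << 4));
}

/* Write the scale of 32-value block blk into the trailing scale section. */
static void q4c_set_scale(uint8_t *row, size_t n, size_t blk, float d) {
    memcpy(row + q4c_scales_offset(n) + blk * sizeof(float), &d, sizeof d);
}

/* Dequantize element k of a row of n values (n must be a multiple of 128). */
static float q4c_get(const uint8_t *row, size_t n, size_t k) {
    assert(n % Q4C_GROUP == 0 && k < n);
    float d;
    memcpy(&d, row + q4c_scales_offset(n) + (k / Q4C_BLOCK) * sizeof(float), sizeof d);
    return d * ((float)q4c_get_quant(row, k) - 8.0f);
}
```

For example, in a 256-value row element 0 sits in the low nibble of byte 0, element 64 in the high nibble of byte 0, and element 128 in the low nibble of byte 64; the eight scales start at byte 128.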
xqs01 = xqs0123 & 0x0f0f...
xqs23 = xqs0123 >> 4 & 0x0f0f...
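The two lines above assume 512-bit registers. As a portable illustration (not the PR's AVX-512 code), the same mask-and-shift can be emulated with 64-bit words (SWAR) over one 64-byte group; q4c_split_group is a hypothetical name.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Split one 64-byte q4_0c group into its low and high nibbles, eight bytes
 * at a time. The per-byte result does not depend on endianness: the shift
 * pulls a neighbour's low nibble into each byte's top half, and the mask
 * clears it again. */
static void q4c_split_group(const uint8_t group[64], uint8_t lo[64], uint8_t hi[64]) {
    const uint64_t mask = 0x0f0f0f0f0f0f0f0fULL;     /* the 0x0f0f... constant */
    for (int i = 0; i < 8; i++) {
        uint64_t w;
        memcpy(&w, group + 8 * i, sizeof w);         /* load 8 bytes */
        const uint64_t w_lo = w & mask;              /* xqs01 = xqs0123 & 0x0f0f...        */
        const uint64_t w_hi = (w >> 4) & mask;       /* xqs23 = (xqs0123 >> 4) & 0x0f0f... */
        memcpy(lo + 8 * i, &w_lo, sizeof w_lo);
        memcpy(hi + 8 * i, &w_hi, sizeof w_hi);
    }
}
```

After the call, lo holds the 4-bit values of the first two 32-value blocks and hi those of the last two, each in element order, so no byte shuffling is needed before the multiply step.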
The dot product implementation here borrows from @dfyz's implementation in #933, but becomes simpler because we don't need to do tricks with the byte layout.
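A scalar reference helps pin down what the SIMD kernels must compute. This is a hedged sketch, not the PR's implementation: dot_q4_0c_q8_0c is a hypothetical name, and the q8_0c layout is assumed to be n int8 quants followed by n/32 float scales, per the description above.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical scalar q4_0c x q8_0c dot product (n must be a multiple of 128).
 * Assumed layouts: x = n/2 nibble bytes (per 64-byte group: low nibbles are
 * values 0..63, high nibbles 64..127), then n/32 float scales;
 * y = n int8 quants, then n/32 float scales. */
static float dot_q4_0c_q8_0c(size_t n, const uint8_t *x, const uint8_t *y) {
    assert(n % 128 == 0);
    const int8_t *yq = (const int8_t *)y;
    float sum = 0.0f;
    for (size_t g = 0; g < n / 128; g++) {
        const uint8_t *xq = x + g * 64;
        for (int half = 0; half < 2; half++) {        /* 0: low nibbles, 1: high nibbles */
            for (int blk = 0; blk < 2; blk++) {       /* two 32-value blocks per half */
                const size_t b = g * 4 + (size_t)(half * 2 + blk);  /* global block index */
                int isum = 0;
                for (int i = 0; i < 32; i++) {
                    const uint8_t byte = xq[blk * 32 + i];
                    const int qx = (half == 0 ? (byte & 0x0f) : (byte >> 4)) - 8;
                    isum += qx * yq[b * 32 + (size_t)i];
                }
                float dx, dy;
                memcpy(&dx, x + n / 2 + b * sizeof(float), sizeof dx);
                memcpy(&dy, y + n + b * sizeof(float), sizeof dy);
                sum += dx * dy * (float)isum;
            }
        }
    }
    return sum;
}
```

Each 32-value block still contributes dx * dy * sum((qx - 8) * qy), as in q4_0; only where the bytes live changes.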
Besides the simplified implementation, there is also a small improvement in performance. Current master q4_0:
llama_print_timings: prompt eval time = 665.66 ms / 6 tokens ( 110.94 ms per token)
llama_print_timings: total time = 15398.10 ms
vs q4_0c:
llama_print_timings: prompt eval time = 449.19 ms / 6 tokens ( 74.86 ms per token)
llama_print_timings: total time = 13557.80 ms
The SIMD implementation with 128-bit registers like Neon should look very similar to the current implementations, with similar speeds. Possibly some benefit from doing only aligned loads. The scalar implementations are slightly more complex but I do not see any degraded performance.
Perplexity should be exactly the same as q4_0.
Example timings (7B)
system_info: n_threads = 4 / 8 | AVX = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 |
Current master q4_0:
llama_print_timings: load time = 797.54 ms
llama_print_timings: sample time = 69.77 ms / 100 runs ( 0.70 ms per run)
llama_print_timings: prompt eval time = 665.66 ms / 6 tokens ( 110.94 ms per token)
llama_print_timings: eval time = 14528.96 ms / 99 runs ( 146.76 ms per run)
llama_print_timings: total time = 15398.10 ms
q4_0 with #933:
llama_print_timings: load time = 751.74 ms
llama_print_timings: sample time = 89.38 ms / 100 runs ( 0.89 ms per run)
llama_print_timings: prompt eval time = 620.75 ms / 6 tokens ( 103.46 ms per token)
llama_print_timings: eval time = 13475.35 ms / 99 runs ( 136.11 ms per run)
llama_print_timings: total time = 14318.72 ms
continuous layout q4_0c:
llama_print_timings: load time = 596.51 ms
llama_print_timings: sample time = 73.39 ms / 100 runs ( 0.73 ms per run)
llama_print_timings: prompt eval time = 449.19 ms / 6 tokens ( 74.86 ms per token)
llama_print_timings: eval time = 12885.82 ms / 99 runs ( 130.16 ms per run)
llama_print_timings: total time = 13557.80 ms
Todos:
- [x] AVX-512 acceleration
- [x] PoC ARM NEON acceleration
- [ ] Investigate performance on M1
- [ ] Test out fully aligned loads by aligning the quantized file to 64 bytes
Future improvements:
- PoC for q4_2-like quantization
- PoC 3-bit quantization
- Support row lengths not divisible by 128. I don't see any big obstacles; it will require some extra code for the leftover block, and padding of rows to keep alignment.
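The padding part of the row-length item above is simple arithmetic. A minimal sketch, assuming one 4-byte float scale per 32 values and padding only at the end of each row; q4c_row_size_padded is a hypothetical helper, not an existing ggml function.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: bytes for one q4_0c-style row of n values, rounded up
 * so the next row starts on an `align`-byte boundary (e.g. 64 for AVX-512).
 * Assumes n is a multiple of 32 and one 4-byte float scale per 32 values. */
static size_t q4c_row_size_padded(size_t n, size_t align) {
    assert(n % 32 == 0);
    assert(align != 0 && (align & (align - 1)) == 0);  /* power of two */
    /* Full 128-value groups take 64 bytes each; a leftover 32-value block
     * (when n % 128 != 0) still takes 16 bytes, so the quants total n/2
     * bytes however the leftover is arranged. */
    const size_t quant_bytes = n / 2;
    const size_t scale_bytes = (n / 32) * sizeof(float);
    const size_t raw = quant_bytes + scale_bytes;
    return (raw + align - 1) & ~(align - 1);           /* round up */
}
```

With align = 64, a 4096-value row needs 2560 bytes (already a multiple of 64), while an 11008-value row needs 6880 bytes and is padded to 6912. Keeping the scale section itself 64-byte aligned would need additional padding between the quants and the scales.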
What do you think about having two separate arrays, one for qs and one for scales?
@unbounded
Thanks for the efforts. Over the last few days we have been speed-running various quantization approaches with a focus on ARM NEON, but I think and hope that we are finally converging on the correct solution: Q4_2 and Q4_3.
I hope that in the next couple of days we confirm that we will proceed with these quantization strategies and merge the ARM NEON implementations to master. At that point, we can investigate alternative memory layouts that are suitable for other arches and that also do not degrade the ARM NEON implementation (potentially like the one proposed in this PR).
What do you think about having two separate arrays, one for qs and one for scales?
Not sure what you mean here: do you mean keeping them as two separate allocations? That would somewhat simplify alignment, but I think it would be hard to do generally for different formats; e.g. q4_1 might use three "arrays".
I'll hold a bit until it stabilizes, but it should be straightforward to test the same approach for the q4_2 format.
Added SIMD for Arm Neon; it's almost identical to q4_0, except that we don't need the vuzp1q_s8 instructions anymore.
I don't have an M1 to test on, but got some timings on an Ampere Altra VM:
system_info: n_threads = 4 / 4 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | VSX = 0 |
q4_0 7B:
llama_print_timings: load time = 1560.01 ms
llama_print_timings: sample time = 99.38 ms / 100 runs ( 0.99 ms per run)
llama_print_timings: prompt eval time = 949.90 ms / 6 tokens ( 158.32 ms per token)
llama_print_timings: eval time = 17879.96 ms / 99 runs ( 180.61 ms per run)
llama_print_timings: total time = 19541.73 ms
q4_0c 7B:
llama_print_timings: load time = 1460.82 ms
llama_print_timings: sample time = 99.60 ms / 100 runs ( 1.00 ms per run)
llama_print_timings: prompt eval time = 865.02 ms / 6 tokens ( 144.17 ms per token)
llama_print_timings: eval time = 19523.40 ms / 99 runs ( 197.21 ms per run)
llama_print_timings: total time = 21086.20 ms
Prefetching seems to be important with this layout; probably the extra memory accesses confuse the hardware prefetcher.
AVX-512: q4_0c 7B, without prefetch:
llama_print_timings: prompt eval time = 449.19 ms / 6 tokens ( 74.86 ms per token)
llama_print_timings: eval time = 12885.82 ms / 99 runs ( 130.16 ms per run)
q4_0c 7B, with prefetch:
llama_print_timings: prompt eval time = 399.98 ms / 6 tokens ( 66.66 ms per token)
llama_print_timings: eval time = 9684.23 ms / 99 runs ( 97.82 ms per run)
Arm Neon: q4_0c 7B, without prefetch:
llama_print_timings: prompt eval time = 865.02 ms / 6 tokens ( 144.17 ms per token)
llama_print_timings: eval time = 19523.40 ms / 99 runs ( 197.21 ms per run)
q4_0c 7B, with prefetch:
llama_print_timings: prompt eval time = 777.49 ms / 6 tokens ( 129.58 ms per token)
llama_print_timings: eval time = 13856.25 ms / 99 runs ( 139.96 ms per run)
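With this layout, software prefetching has to cover two streams: the quant bytes and the separately stored scales. A minimal sketch of the pattern using the GCC/Clang __builtin_prefetch builtin; q4c_row_sum_prefetch and its ahead parameter (distance in 128-value groups) are hypothetical, and a row sum stands in for the real dot product to keep it short.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Sum of the dequantized values of a q4_0c-style row (n multiple of 128),
 * prefetching both memory streams `ahead` groups in advance.
 * Assumed layout: n/2 nibble bytes, then n/32 float scales. */
static float q4c_row_sum_prefetch(size_t n, const uint8_t *row, size_t ahead) {
    assert(n % 128 == 0);
    const uint8_t *quants = row;          /* stream 1: 64 bytes per group  */
    const uint8_t *scales = row + n / 2;  /* stream 2: 4 floats per group */
    const size_t ngroups = n / 128;
    float sum = 0.0f;
    for (size_t g = 0; g < ngroups; g++) {
#if defined(__GNUC__)
        if (g + ahead < ngroups) {
            /* rw = 0 (read), locality = 3 (keep in all cache levels) */
            __builtin_prefetch(quants + (g + ahead) * 64, 0, 3);
            __builtin_prefetch(scales + (g + ahead) * 4 * sizeof(float), 0, 3);
        }
#endif
        for (int blk = 0; blk < 4; blk++) {
            /* blocks 0,1 are in the low nibbles, blocks 2,3 in the high ones */
            const uint8_t *q = quants + g * 64 + (size_t)(blk % 2) * 32;
            int isum = 0;
            for (int i = 0; i < 32; i++)
                isum += (blk < 2 ? (q[i] & 0x0f) : (q[i] >> 4)) - 8;
            float d;
            memcpy(&d, scales + (g * 4 + (size_t)blk) * sizeof(float), sizeof d);
            sum += d * (float)isum;
        }
    }
    return sum;
}
```

The result does not depend on ahead; only the memory timing does. The best distance is hardware dependent, and on some CPUs the hints may not help at all.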
On M1 Pro, Q4_0c is currently more than 20% slower than Q4_0 using 8 threads (and even more for 4 threads):
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:15:59 ⚓ master-50cb666-9-g58e10f2 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 4
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 2 (mostly Q4_0)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to learn, live and love.
There are three keys to living a happy life: Love what you do, Work with passionate people, Serve others who need your help.
Doing the things that matter most in life.
Being able to see beauty in everything around you.
The ability to
llama_print_timings: load time = 507.55 ms
llama_print_timings: sample time = 47.07 ms / 64 runs ( 0.74 ms per run)
llama_print_timings: prompt eval time = 498.11 ms / 8 tokens ( 62.26 ms per token)
llama_print_timings: eval time = 3561.34 ms / 63 runs ( 56.53 ms per run)
llama_print_timings: total time = 4116.99 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:16:13 ⚓ master-50cb666-9-g58e10f2 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0c.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 4
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0c.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 7 (mostly Q4_0C)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to serve others with love and compassion. To share my time, energy and money to make this world a better place. That's where my passion comes from!
My parents gave me so much love, that I want to give it back to people who need it most. There are so many children around
llama_print_timings: load time = 668.47 ms
llama_print_timings: sample time = 47.20 ms / 64 runs ( 0.74 ms per run)
llama_print_timings: prompt eval time = 659.63 ms / 8 tokens ( 82.45 ms per token)
llama_print_timings: eval time = 5920.76 ms / 63 runs ( 93.98 ms per run)
llama_print_timings: total time = 6637.47 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:16:25 ⚓ master-50cb666-9-g58e10f2 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 8
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 2 (mostly Q4_0)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to learn, live and love.
There are three keys to living a happy life: Love what you do, Work with passionate people, Serve others who need your help.
Doing the things that matter most in life.
Being able to see beauty in everything around you.
The ability to
llama_print_timings: load time = 335.87 ms
llama_print_timings: sample time = 47.24 ms / 64 runs ( 0.74 ms per run)
llama_print_timings: prompt eval time = 326.36 ms / 8 tokens ( 40.79 ms per token)
llama_print_timings: eval time = 3228.42 ms / 63 runs ( 51.24 ms per run)
llama_print_timings: total time = 3612.50 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:16:35 ⚓ master-50cb666-9-g58e10f2 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0c.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 8
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0c.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 7 (mostly Q4_0C)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to serve others with love and compassion. To share my time, energy and money to make this world a better place. That's where my passion comes from!
My parents gave me so much love, that I want to give it back to people who need it most. There are so many children around
llama_print_timings: load time = 452.52 ms
llama_print_timings: sample time = 46.99 ms / 64 runs ( 0.73 ms per run)
llama_print_timings: prompt eval time = 443.12 ms / 8 tokens ( 55.39 ms per token)
llama_print_timings: eval time = 4073.24 ms / 63 runs ( 64.65 ms per run)
llama_print_timings: total time = 4573.76 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:16:46 ⚓ master-50cb666-9-g58e10f2 8⎘ $
Btw, without prefetching (i.e. the previous commit), Q4_0c is about the same as Q4_0:
20:20:57 ⚓ master-50cb666-8-g64a6a29 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 4
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 2 (mostly Q4_0)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to learn, live and love.
There are three keys to living a happy life: Love what you do, Work with passionate people, Serve others who need your help.
Doing the things that matter most in life.
Being able to see beauty in everything around you.
The ability to
llama_print_timings: load time = 491.92 ms
llama_print_timings: sample time = 46.61 ms / 64 runs ( 0.73 ms per run)
llama_print_timings: prompt eval time = 482.84 ms / 8 tokens ( 60.36 ms per token)
llama_print_timings: eval time = 3435.16 ms / 63 runs ( 54.53 ms per run)
llama_print_timings: total time = 3974.66 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:21:06 ⚓ master-50cb666-8-g64a6a29 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0c.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 4
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0c.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 7 (mostly Q4_0C)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to serve others with love and compassion. To share my time, energy and money to make this world a better place. That's where my passion comes from!
My parents gave me so much love, that I want to give it back to people who need it most. There are so many children around
llama_print_timings: load time = 552.26 ms
llama_print_timings: sample time = 46.86 ms / 64 runs ( 0.73 ms per run)
llama_print_timings: prompt eval time = 542.64 ms / 8 tokens ( 67.83 ms per token)
llama_print_timings: eval time = 3856.36 ms / 63 runs ( 61.21 ms per run)
llama_print_timings: total time = 4456.43 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:21:18 ⚓ master-50cb666-8-g64a6a29 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 8
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 2 (mostly Q4_0)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to learn, live and love.
There are three keys to living a happy life: Love what you do, Work with passionate people, Serve others who need your help.
Doing the things that matter most in life.
Being able to see beauty in everything around you.
The ability to
llama_print_timings: load time = 350.05 ms
llama_print_timings: sample time = 46.93 ms / 64 runs ( 0.73 ms per run)
llama_print_timings: prompt eval time = 340.74 ms / 8 tokens ( 42.59 ms per token)
llama_print_timings: eval time = 2984.46 ms / 63 runs ( 47.37 ms per run)
llama_print_timings: total time = 3382.31 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:21:31 ⚓ master-50cb666-8-g64a6a29 8⎘ $ make -j && ./main -m ./models/7B/ggml-model-q4_0c.bin -p "I believe the meaning of life is" -c 2048 -n 512 --ignore-eos -s 5 -n 64 -t 8
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
I CFLAGS: -I. -O3 -DNDEBUG -std=c11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS: -framework Accelerate
I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1)
make: Nothing to be done for `default'.
main: seed = 5
llama.cpp: loading model from ./models/7B/ggml-model-q4_0c.bin
llama_model_load_internal: format = ggjt v1 (latest)
llama_model_load_internal: n_vocab = 32000
llama_model_load_internal: n_ctx = 2048
llama_model_load_internal: n_embd = 4096
llama_model_load_internal: n_mult = 256
llama_model_load_internal: n_head = 32
llama_model_load_internal: n_layer = 32
llama_model_load_internal: n_rot = 128
llama_model_load_internal: ftype = 7 (mostly Q4_0C)
llama_model_load_internal: n_ff = 11008
llama_model_load_internal: n_parts = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size = 59.11 KB
llama_model_load_internal: mem required = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size = 1024.00 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
sampling: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.100000
generate: n_ctx = 2048, n_batch = 8, n_predict = 64, n_keep = 0
I believe the meaning of life is to serve others with love and compassion. To share my time, energy and money to make this world a better place. That's where my passion comes from!
My parents gave me so much love, that I want to give it back to people who need it most. There are so many children around
llama_print_timings: load time = 374.11 ms
llama_print_timings: sample time = 47.05 ms / 64 runs ( 0.74 ms per run)
llama_print_timings: prompt eval time = 364.90 ms / 8 tokens ( 45.61 ms per token)
llama_print_timings: eval time = 2990.87 ms / 63 runs ( 47.47 ms per run)
llama_print_timings: total time = 3412.83 ms
ggerganov Georgis-MBP ~/development/github/llama.cpp
20:21:40 ⚓ master-50cb666-8-g64a6a29 8⎘ $
Hm, it sounds like it is sensitive to which prefetch distance is used, then; that's unfortunate. Thanks for checking! It executes strictly fewer instructions, so I think performance should always be at least on par, but possibly the M1 just doesn't benefit from prefetching hints.
For reference, here are the times I saw on the Ampere Altra for different values of ahead:
[Figure: eval time for different values of ahead]
[Figure: prompt eval time for different values of ahead]
Somewhat related to this: Q8_0, as it stands after #1083 and #1109, now has two floats that go to waste for Q4_0 and Q4_2, at least in the AVX2 implementation. This makes quantization slower, because it calculates unused values, and makes the vector dot product slower, because it has to churn through more memory.
We could define a new format, but this again makes the source code longer:
#include <stdint.h>

#define QK8_0 32
typedef struct {
    float  d;          // delta
    int8_t qs[QK8_0];  // quants
} block_q8_0;

#define QK8_1 32
typedef struct {
    float  d;          // delta
    float  s0;         // d * sum(qs[i]) low
    float  s1;         // d * sum(qs[i]) high
    int8_t qs[QK8_1];  // quants
} block_q8_1;
Edit: this was done in #1179
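To illustrate the overhead described above at the block level, here is a hedged sketch of quantizing 32 floats to int8, where the two Q8_1-style sums are an optional extra pass. quantize_block_q8 is hypothetical, not the ggml reference function; it assumes d = amax / 127 and that "low"/"high" in the struct comments mean the first and last 16 quants.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Round half away from zero (like roundf) without needing libm. */
static int round_away(float v) { return v >= 0.0f ? (int)(v + 0.5f) : -(int)(-v + 0.5f); }

/* Quantize 32 floats to int8 with one scale d = amax / 127. If s0 and s1 are
 * non-NULL, also fill the two extra Q8_1-style fields, assuming "low"/"high"
 * mean the first and last 16 quants. Q8_0-style callers pass NULL. */
static void quantize_block_q8(const float x[32], int8_t qs[32], float *d,
                              float *s0, float *s1) {
    float amax = 0.0f;
    for (int i = 0; i < 32; i++) {
        const float a = x[i] < 0.0f ? -x[i] : x[i];
        if (a > amax) amax = a;
    }
    const float dd = amax / 127.0f;
    const float id = dd != 0.0f ? 1.0f / dd : 0.0f;
    for (int i = 0; i < 32; i++) qs[i] = (int8_t)round_away(x[i] * id);
    *d = dd;
    if (s0 != NULL && s1 != NULL) {   /* extra pass that Q4_0/Q4_2 never use */
        int lo = 0, hi = 0;
        for (int i = 0; i < 16; i++)  lo += qs[i];
        for (int i = 16; i < 32; i++) hi += qs[i];
        *s0 = dd * (float)lo;
        *s1 = dd * (float)hi;
    }
}
```

A Q4_0 or Q4_2 dot product never reads s0/s1, so for those formats the extra pass and the 8 extra bytes per block are pure overhead.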
I started a branch for the same approach with the q4_2 data format: https://github.com/unbounded/llama.cpp/tree/continuous-q4_2c
On AMD64 with AVX-512, q4_2:
llama_print_timings: prompt eval time = 899.29 ms / 6 tokens ( 149.88 ms per token)
llama_print_timings: eval time = 19050.60 ms / 99 runs ( 192.43 ms per run)
q4_2c:
llama_print_timings: prompt eval time = 564.52 ms / 6 tokens ( 94.09 ms per token)
llama_print_timings: eval time = 10912.85 ms / 99 runs ( 110.23 ms per run)
@unbounded
I am a bit torn about the proposed big block size of 4*32: it will limit the application, since not all models have rows divisible by 128.
I was wondering: if we proceed with a block size of 32 as in #1305, would it make sense for the AVX512 dot product implementation to memcpy 4 blocks into a local array in the proper 512-bit layout and fall back to the implementation here? The memcpy would obviously be an overhead, but maybe it wouldn't be too big and we would still get most of the benefit from this implementation.
Would this work?
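The memcpy idea above could look roughly like this: gather four consecutive 32-value blocks into a local 64-byte group plus four scales, then run the continuous-layout kernel on that copy. Everything here is an assumption: blk32 and q4c_repack4 are hypothetical, the block is assumed to store a float scale with the low nibble of qs[j] holding value j and the high nibble value j + 16, so check the real block_q4_0 definition before reusing it. Under that assumed nibble order a plain memcpy would not produce the continuous order, so the sketch repacks nibbles explicitly.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical stand-in for a 32-value block. */
typedef struct {
    float   d;       /* scale (the real block may store this as fp16) */
    uint8_t qs[16];  /* assumed: low nibble = value j, high nibble = value j + 16 */
} blk32;

/* Repack 4 consecutive blocks into one 64-byte group in the continuous layout
 * (low nibbles: values 0..63, high nibbles: values 64..127) plus 4 scales. */
static void q4c_repack4(const blk32 src[4], uint8_t group[64], float scales[4]) {
    memset(group, 0, 64);
    for (int b = 0; b < 4; b++) {
        scales[b] = src[b].d;
        for (int j = 0; j < 32; j++) {
            const uint8_t byte = src[b].qs[j % 16];
            const uint8_t q = (uint8_t)(j < 16 ? (byte & 0x0f) : (byte >> 4));
            const int k = 32 * b + j;                        /* position among the 128 values */
            if (k < 64) group[k] |= q;                       /* blocks 0,1 -> low nibbles  */
            else        group[k - 64] |= (uint8_t)(q << 4);  /* blocks 2,3 -> high nibbles */
        }
    }
}
```

The repack costs one pass over 64 bytes per 128 values; whether that overhead eats the gain from the wider kernel is the open performance question.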
@ggerganov That is a thing I can test the performance impact of, but it will probably be a little while before I get around to it.
I don't see any reason we couldn't add a "rest" handler for row sizes that are not a multiple of the block size; the main disadvantages would be the padding and a bit of extra code. For now I'm looking to establish a baseline of what's achievable; possibly most of the wins can be achieved with less drastic changes like #1305. The next step here is probably testing some more variations on an M1 as well, to see if there are performance improvements to be had there.
q4_2 timings on Ampere Altra, 7B
q4_2:
llama_print_timings: prompt eval time = 1481.07 ms / 6 tokens ( 246.85 ms per token)
llama_print_timings: eval time = 24994.51 ms / 99 runs ( 252.47 ms per run)
q4_2c:
llama_print_timings: prompt eval time = 1212.94 ms / 6 tokens ( 202.16 ms per token)
llama_print_timings: eval time = 21673.20 ms / 99 runs ( 218.92 ms per run)
Did some performance testing on M1 and updated the q4_2c branch. Now I see a performance win here as well:
q4_2:
llama_print_timings: prompt eval time = 592.93 ms / 6 tokens ( 98.82 ms per token)
llama_print_timings: eval time = 4458.55 ms / 50 runs ( 89.17 ms per run)
q4_2c:
llama_print_timings: prompt eval time = 486.42 ms / 6 tokens ( 81.07 ms per token)
llama_print_timings: eval time = 7533.07 ms / 99 runs ( 76.09 ms per run)
Some miscellaneous performance observations:
I saw no performance difference using 64-byte aligned loads with AVX-512.
Prefetching instructions give no benefit at all on M1 processors. On other CPUs they helped narrow the difference between prompt eval and prediction. Prediction time being much slower than prompt eval may indicate that prefetch instructions can improve performance.
Some very interesting work in #1256: the "super blocks" mentioned there are probably large enough to capture most of the benefit of this layout, if they are properly arranged. That format also uses drastically fewer Float16 numbers, so there is less benefit in doing 2 or 4 F16->F32 conversions at once like we can here.