llama.cpp
Server: enable lookup decoding
This PR aims to enable lookup decoding for the llama.cpp server in the same way as it is used in `examples/lookup`, see https://github.com/ggerganov/llama.cpp/pull/5479. To recapitulate, the implementation tries to guess the next few tokens that will be generated using simple text statistics. I think the current implementation works, but it is difficult to benchmark properly. The intended way for it to work is:
- Start with empty context cache, empty dynamic cache (unless the user provides one), and static cache loaded from file.
- When generating tokens, try to continue with context cache, validate with static cache.
- If that fails, try to continue with dynamic cache, validate with static cache.
- If that fails, try to continue with static cache.
- When new tokens are generated, use them to update the context cache.
- When a generation is finished, update the dynamic cache with the context cache, then empty the context cache.
- On server shutdown, save dynamic cache to file if the user provided a path.
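The cascade in the list above can be sketched as follows. This is a minimal, hypothetical illustration, not the code in this PR: the caches are plain `std::map`s keyed by the preceding tokens, all three caches use 2-token keys (the PR's context and dynamic caches actually hold 1- to 4-grams), and "validate with static cache" is modeled by an assumed rule that rejects a candidate only if the static cache has data for the same context and never saw the candidate there. The names `cache_update`, `static_allows`, `predict`, `draft` and `cache_merge` are made up for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <vector>

using token       = int32_t;
using context     = std::vector<token>;
// Hypothetical n-gram cache: preceding tokens -> (next token -> count).
using ngram_cache = std::map<context, std::map<token, int>>;

// Count, for every position in toks, which token followed the ctx_len tokens before it.
void cache_update(ngram_cache & cache, const std::vector<token> & toks, size_t ctx_len) {
    assert(ctx_len >= 1);
    for (size_t i = ctx_len; i < toks.size(); ++i) {
        const context ctx(toks.begin() + static_cast<std::ptrdiff_t>(i - ctx_len),
                          toks.begin() + static_cast<std::ptrdiff_t>(i));
        cache[ctx][toks[i]]++;
    }
}

// Assumed validation rule: reject a candidate only if the static cache has
// data for this context and never saw the candidate after it.
bool static_allows(const ngram_cache & static_cache, const context & ctx, token candidate) {
    const auto it = static_cache.find(ctx);
    return it == static_cache.end() || it->second.count(candidate) > 0;
}

// Most frequent continuation of the last ctx_len tokens, skipping candidates
// that the validator (if any) rejects.
std::optional<token> predict(const ngram_cache & cache, const std::vector<token> & toks,
                             size_t ctx_len, const ngram_cache * validator) {
    if (ctx_len == 0 || toks.size() < ctx_len) {
        return std::nullopt;
    }
    const context ctx(toks.end() - static_cast<std::ptrdiff_t>(ctx_len), toks.end());
    const auto it = cache.find(ctx);
    if (it == cache.end()) {
        return std::nullopt;
    }
    std::optional<token> best;
    int best_count = 0;
    for (const auto & [tok, count] : it->second) {
        if (validator && !static_allows(*validator, ctx, tok)) {
            continue;
        }
        if (count > best_count) {
            best       = tok;
            best_count = count;
        }
    }
    return best;
}

// Cascade: context cache, then dynamic cache (both validated by the static
// cache), then the static cache on its own.
std::vector<token> draft(std::vector<token> toks, const ngram_cache & ctx_cache,
                         const ngram_cache & dyn_cache, const ngram_cache & static_cache,
                         size_t n_draft) {
    std::vector<token> out;
    while (out.size() < n_draft) {
        std::optional<token> next = predict(ctx_cache, toks, 2, &static_cache);
        if (!next) next = predict(dyn_cache, toks, 2, &static_cache);
        if (!next) next = predict(static_cache, toks, 2, nullptr);
        if (!next) break;
        out.push_back(*next);
        toks.push_back(*next);  // keep drafting from the extended sequence
    }
    return out;
}

// End of a generation: fold the context cache into the dynamic cache, then empty it.
void cache_merge(ngram_cache & target, ngram_cache & source) {
    for (const auto & [ctx, counts] : source) {
        for (const auto & [tok, count] : counts) {
            target[ctx][tok] += count;
        }
    }
    source.clear();
}
```

In this sketch `cache_merge(dyn_cache, ctx_cache)` stands in for the step where the context cache is folded into the dynamic cache and then emptied; loading the static cache and saving the dynamic cache to a file are left out.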
These are the results I get from `examples/server/bench.py` using an RTX 4090, various static lookup caches, and an initially empty dynamic cache:
Model | Static lookup cache | Iterations ~~master~~ PR --draft 0 | Iterations PR --draft 5 | Speedup |
---|---|---|---|---|
Phi-2 3b q4_0 | None | 274 | 365 | 1.33 |
Phi-2 3b q4_0 | Wikitext 103 | 274 | 361 | 1.32 |
Phi-2 3b q4_0 | Mistral 1.64e8 | 274 | 354 | 1.29 |
LLaMA 2 7b q4_0 | None | 148 | 256 | 1.73 |
LLaMA 2 7b q4_0 | Wikitext 103 | 148 | 255 | 1.72 |
LLaMA 2 7b q4_0 | Mistral 1.64e8 | 148 | 255 | 1.72 |
Edit: the table was labeled incorrectly. The speedup was not relative to master but relative to `--draft 0`, which includes the lookup overhead but none of the benefit.
It does seem to provide a speedup, but adding a static lookup cache does not seem to help (the caches are created either from Wikitext 103 or from 164 million tokens generated with Mistral q8_0). Assuming there are no bugs, I think what is happening is that the benchmark dataset (see the server bench README) is very repetitive, so using a static cache pulls the drafts away from these repetitive patterns and reduces the speed. Also, for Phi-2 in particular, I think that I simply don't have enough input data for the static cache to get sufficiently precise text statistics (since it has a larger vocabulary size). Regarding the latter, I recently built a machine with 6x RTX 4090, so I think I will be able to significantly scale up the rate at which I can produce synthetic text (I was previously using 3x P40 and 1x RX 6800).
In this PR I also changed the interface of `llama_ngram_cache_load` to be more in line with the rest of llama.cpp; I may change the interface of some of the other functions as well.
Also: is it somehow possible to retrieve the tokens that were previously fed to the model? I'm currently tracking this manually in `server.cpp`, but this adds potential for error.
Great work. As we discussed previously, the server's test coverage matters, and adding a new scenario in the test framework is mandatory.
> adding a new scenario in the test framework is mandatory.

Are there already any tests that assert correctness for the server? I didn't see any, so as part of this implementation I would try to add some.
> Are there already any tests that assert correctness for the server? I didn't see any, so as part of this implementation I would try to add some.

https://github.com/ggerganov/llama.cpp/tree/master/examples/server/tests
While writing tests I'm noticing that when using more than one slot, the results for a given seed are not consistent on master. @phymbert is this a known problem?
I was not aware, but this is not asserted in the parallel test suite AFAIK.
Also, I recall that each architecture generates different results.
A research paper studying this exact technique was recently published and suggested for integration in an issue (https://github.com/ggerganov/llama.cpp/issues/6813). I have been looking at the current implementation and trying to make it match the implementation in the paper.
- The short version is that there are no `static` or `dynamic` N-gram caches, only ones like what you call `context`, generated on the fly.
- They use what they call a multi-level N-gram. There are multiple caches. Each cache only considers N-grams of length N. When querying the multi-level N-gram, the module with the longest N-grams that match the token suffix is the one used.
- They then generate K tokens using this multi level N-gram as a proposal to be validated by the LLM.
- Based on their ablation study, they suggest N = 5 and K = 7.
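A minimal sketch of the multi-level N-gram idea as summarized above, assuming that an N-gram means N - 1 context tokens followed by one predicted token, and that each level proposes its most frequent continuation. The class `multilevel_ngram` and its methods are hypothetical; they are not taken from the paper or from llama.cpp.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <vector>

using token   = int32_t;
using context = std::vector<token>;

// Hypothetical multi-level N-gram model: level n only stores n-grams of
// exactly length n, i.e. (n - 1) context tokens followed by the next token.
class multilevel_ngram {
public:
    explicit multilevel_ngram(size_t n_max) : n_max_(n_max), levels_(n_max + 1) {
        assert(n_max >= 2);
    }

    // Record every n-gram (2 <= n <= n_max) that occurs in toks.
    void add(const std::vector<token> & toks) {
        for (size_t i = 1; i < toks.size(); ++i) {
            for (size_t n = 2; n <= n_max_ && n - 1 <= i; ++n) {
                const context ctx(toks.begin() + static_cast<std::ptrdiff_t>(i - (n - 1)),
                                  toks.begin() + static_cast<std::ptrdiff_t>(i));
                levels_[n][ctx][toks[i]]++;
            }
        }
    }

    // Query the longest level whose context matches the suffix of toks.
    std::optional<token> predict(const std::vector<token> & toks) const {
        for (size_t n = std::min(n_max_, toks.size() + 1); n >= 2; --n) {
            const context ctx(toks.end() - static_cast<std::ptrdiff_t>(n - 1), toks.end());
            const auto it = levels_[n].find(ctx);
            if (it == levels_[n].end()) {
                continue;  // fall back to a shorter N
            }
            token best      = 0;
            int  best_count = 0;
            for (const auto & [tok, count] : it->second) {
                if (count > best_count) {
                    best       = tok;
                    best_count = count;
                }
            }
            return best;
        }
        return std::nullopt;
    }

    // Propose up to k tokens by repeatedly appending the prediction.
    std::vector<token> draft(std::vector<token> toks, size_t k) const {
        std::vector<token> out;
        while (out.size() < k) {
            const std::optional<token> next = predict(toks);
            if (!next) {
                break;
            }
            out.push_back(*next);
            toks.push_back(*next);
        }
        return out;
    }

private:
    size_t n_max_;
    std::vector<std::map<context, std::map<token, int>>> levels_;
};
```

With the values suggested by the ablation study, this would be `multilevel_ngram m(5);` filled from the current context via `m.add(...)` and queried with `m.draft(tokens, 7)`; the draft is then validated by the LLM as usual.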
I haven't spent enough time reading `ngram-cache.cpp` and friends to tell how it works and how it differs, besides the persistence.
I will post results if there are any. I'm new to this codebase and a bit rusty at C++. Adjust your expectations accordingly.
Thanks for the input. I saw the paper but didn't yet get around to reading it.
> They use what they call a multi-level N-gram. There are multiple caches. Each cache only considers N-grams of length N. When querying the multi-level N-gram, the module with the longest N-grams that match the token suffix is the one used.

In practice my implementation also uses N-grams of varying sizes. A `llama_ngram_cache` can contain N-grams of multiple sizes simultaneously; it's just easier to bundle them together. The context and dynamic caches contain 1-grams, 2-grams, 3-grams, and 4-grams. The static caches only contain 2-grams (given enough input data, 3-grams or 4-grams should also be viable).
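To illustrate how a single container can hold N-grams of several sizes at once, here is a hypothetical sketch that pads every key to a fixed length of 4 with `-1`, mirroring the `LLAMA_NGRAM_MAX` / `-1` convention mentioned later in this thread. It uses `std::map` to sidestep hashing, the type and function names are illustrative only, and it assumes the N-gram size counts the key tokens that precede the predicted token.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

using token = int32_t;

// Mirrors LLAMA_NGRAM_MAX from the thread: keys hold at most 4 tokens.
constexpr int NGRAM_MAX = 4;
// Fixed-size key; unused trailing slots are -1, so keys of different sizes
// never compare equal and can share one container.
using ngram_key     = std::array<token, NGRAM_MAX>;
using bundled_cache = std::map<ngram_key, std::map<token, int>>;

// Key made of the n tokens that precede position `end` in toks.
ngram_key make_key(const std::vector<token> & toks, size_t end, int n) {
    assert(n >= 1 && n <= NGRAM_MAX && static_cast<size_t>(n) <= end && end <= toks.size());
    ngram_key key;
    key.fill(-1);
    for (int j = 0; j < n; ++j) {
        key[static_cast<size_t>(j)] = toks[end - static_cast<size_t>(n) + static_cast<size_t>(j)];
    }
    return key;
}

// For every position i, count toks[i] as the continuation of each key made
// from the n_min..n_max tokens before it.
void cache_update(bundled_cache & cache, const std::vector<token> & toks, int n_min, int n_max) {
    assert(1 <= n_min && n_min <= n_max && n_max <= NGRAM_MAX);
    for (size_t i = 0; i < toks.size(); ++i) {
        for (int n = n_min; n <= n_max && static_cast<size_t>(n) <= i; ++n) {
            cache[make_key(toks, i, n)][toks[i]]++;
        }
    }
}
```

Under these assumptions a context or dynamic cache would be filled with `cache_update(cache, tokens, 1, 4)` and a static cache with `cache_update(cache, tokens, 2, 2)`.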
I forgot: if you want to play around with the llama.cpp implementation, take a look at `lookup-stats`. It has the same interface as `perplexity` and can be used to estimate how many tokens would be predicted on a pre-generated text (so you don't have to actually evaluate the model).
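The idea of replaying pre-generated text instead of evaluating the model can be illustrated with a much-simplified, hypothetical simulation. This is not the code of `lookup-stats`: it drafts greedily from a single bigram cache built only from tokens already seen, accepts the longest matching prefix of each draft, and counts one simulated model evaluation per accepted prefix plus one token. `lookup_stats` and `simulate_lookup` are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

using token = int32_t;

struct lookup_stats {
    size_t n_steps    = 0;  // simulated model evaluations
    size_t n_drafted  = 0;  // draft tokens proposed
    size_t n_accepted = 0;  // draft tokens that matched the text
};

// text[0] acts as the prompt; every later token is treated as "generated".
lookup_stats simulate_lookup(const std::vector<token> & text, size_t n_draft) {
    std::map<token, std::map<token, int>> cache;  // previous token -> next token counts
    lookup_stats stats;
    size_t i = 1;  // index of the next token to be "generated"
    while (i < text.size()) {
        // Draft greedily from the most frequent continuation seen so far.
        std::vector<token> draft;
        token last = text[i - 1];
        while (draft.size() < n_draft) {
            const auto it = cache.find(last);
            if (it == cache.end()) {
                break;
            }
            const auto best = std::max_element(it->second.begin(), it->second.end(),
                [](const auto & a, const auto & b) { return a.second < b.second; });
            draft.push_back(best->first);
            last = best->first;
        }
        stats.n_drafted += draft.size();

        // Accept the longest prefix of the draft that matches the text.
        size_t accepted = 0;
        while (accepted < draft.size() && i + accepted < text.size() &&
               draft[accepted] == text[i + accepted]) {
            ++accepted;
        }
        stats.n_accepted += accepted;

        // One evaluation yields the accepted tokens plus one token from the model.
        const size_t advance = std::min(accepted + 1, text.size() - i);
        for (size_t p = i; p < i + advance; ++p) {
            cache[text[p - 1]][text[p]]++;
        }
        i += advance;
        ++stats.n_steps;
    }
    assert(stats.n_accepted <= stats.n_drafted);
    return stats;
}
```

Dividing the number of generated tokens by `n_steps` gives the average number of tokens per simulated evaluation; the sketch ignores the cost of the lookup itself.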
How much does this PR increase token generation speed? As far as I am aware, #5479 had a rather tiny speedup. And when do you think this PR will be ready to be merged?
> How much does this PR increase token generation speed? As far as I am aware, https://github.com/ggerganov/llama.cpp/pull/5479 had a rather tiny speedup.

Should be something like 1.1-1.4x for natural language with an essentially empty context. For source code or summarization it's going to be a lot more. The numbers in the OP are indicative of the long-term speedup using similar prompts once the dynamic cache fills up.
> And when do you think this PR will be ready to be merged?

I'm aiming for the end of the week.
Does this allow us to create a static cache during inference?
No, use `lookup-create` for that. I'll upload the caches that I've been using myself before I merge this PR.
Functionally I think everything is in order now. Unfortunately I think that it's currently not possible to get bit-for-bit identical results with lookup decoding since the results seem to change slightly when the batch size is varied, see https://github.com/ggerganov/llama.cpp/pull/6950 . For this reason there are no automated tests for lookup decoding that assert that the results do not change (because they do).
📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 484 iterations 🚀
- Concurrent users: 8, duration: 10m
- HTTP request : avg=9681.58ms p(95)=24109.08ms fails=, finish reason: stop=429 truncated=55
- Prompt processing (pp): avg=116.51tk/s p(95)=508.02tk/s
- Token generation (tg): avg=28.36tk/s p(95)=54.29tk/s
- ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=server-ngram-4 commit=71c98cc3bd4afd19a813b89197b93a29d8cc0e86
[Charts: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 484 iterations; time series of `llamacpp:prompt_tokens_seconds`, `llamacpp:predicted_tokens_seconds`, `llamacpp:kv_cache_usage_ratio` and `llamacpp:requests_processing`]
I'm very sorry, but it seems the numbers that I previously reported were incorrect. The speed I reported for "master" was actually the speed for this PR with `--draft 0`. This means those numbers still included the overhead associated with the lookup caches, which is quite significant. These are the correct numbers for the most recent version:
Model | Static lookup cache | Slots | Iterations master | Iterations PR --draft 0 | Iterations PR --draft 5 | Speedup vs. master | Speedup vs. --draft 0 |
---|---|---|---|---|---|---|---|
Phi-2 3b q4_0 | None | 1 | 549 | 274 | 363 | 0.66 | 1.32 |
Phi-2 3b q4_0 | None | 2 | 947 | 455 | 599 | 0.63 | 1.32 |
Phi-2 3b q4_0 | None | 4 | 1465 | 704 | 797 | 0.54 | 1.13 |
Phi-2 3b q4_0 | None | 8 | 1856 | 855 | 900 | 0.48 | 1.05 |
For Phi-2 on an RTX 4090 there is a regression relative to master: the model is so fast that the constant per-token overhead outweighs the speedup. I'll investigate performance for larger models/slower hardware.
Performance for LLaMA 3 70b on 3x RTX 4090 is looking much better:
Model | Static lookup cache | Slots | Iterations master | Iterations PR --draft 5 | Speedup vs. master |
---|---|---|---|---|---|
LLaMA 3 70b q4_K_M | None | 1 | 24 | 44 | 1.83 |
LLaMA 3 70b q4_K_M | WT 103 | 1 | 24 | 42 | 1.75 |
Regarding performance, it seems your hashes for the lookup table are of low quality. `std::hash<llama_token>` is the same as `std::hash<int32_t>`, which (at least in libstdc++ and libc++) just returns the token value unchanged. Also, the standard containers are known not to be best in class performance-wise, but that's a different issue. :)
edit: this is a world-class article on how to make a fast lookup table, including a pretty neat hash function: https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
edit2: the way the hashes are combined in the n-gram means that of the 64 bits, only 32 have any entropy at all. A better hash would probably fix this, but hashes are often combined with an extra shift or another multiplication.
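As a small illustration of the point made in the linked article, the sketch below contrasts a power-of-two table indexed by the low bits of an identity hash with Fibonacci hashing, which multiplies by 2^64/φ (11400714819323198485) and keeps the top bits. The function names are hypothetical, the bucket count is assumed to be a power of two, and this is not a description of how every standard library's `std::unordered_map` picks buckets.

```cpp
#include <cassert>
#include <cstdint>

using token = int32_t;

// 2^64 divided by the golden ratio, the multiplier from the article.
constexpr uint64_t FIB_MULT = 11400714819323198485ull;

// Table of 2^bits buckets indexed by the low bits of an identity hash:
// tokens that differ only above bit `bits` always collide.
uint64_t bucket_low_bits(token t, unsigned bits) {
    assert(bits > 0 && bits < 64);
    return static_cast<uint64_t>(static_cast<uint32_t>(t)) & ((uint64_t{1} << bits) - 1);
}

// Fibonacci hashing: multiply by 2^64/phi, then keep the top `bits` bits,
// so low and high input bits both affect the bucket index.
uint64_t bucket_fibonacci(token t, unsigned bits) {
    assert(bits > 0 && bits < 64);
    const uint64_t h = static_cast<uint64_t>(static_cast<uint32_t>(t)) * FIB_MULT;
    return h >> (64 - bits);
}
```

With 1024 buckets (`bits = 10`), token IDs 0, 1024, 2048, ... all land in bucket 0 under the low-bit scheme, while under Fibonacci hashing they no longer share a single bucket.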
Thank you for the high-quality post. I definitely agree that the hashing is suboptimal; my main concern for now is to get something that works at all and to implement tests that assert this.
Before reading the blog post on hash functions, I wrote a simple implementation that just uses bit shifts and XORs, and that alone already results in much better performance:
Model | Static lookup cache | Slots | Iterations master | Iterations PR --draft 5 | Speedup vs. master |
---|---|---|---|---|---|
Phi-2 3b q4_0 | None | 1 | 549 | 634 | 1.15 |
Phi-2 3b q4_0 | None | 2 | 947 | 1113 | 1.18 |
Phi-2 3b q4_0 | None | 4 | 1465 | 1572 | 1.07 |
Phi-2 3b q4_0 | None | 8 | 1856 | 1790 | 0.96 |
Phi-2 3b q4_0 | WT 103 | 1 | 549 | 643 | 1.17 |
Phi-2 3b q4_0 | WT 103 | 2 | 947 | 1098 | 1.16 |
Phi-2 3b q4_0 | WT 103 | 4 | 1465 | 1549 | 1.06 |
Phi-2 3b q4_0 | WT 103 | 8 | 1856 | 1766 | 0.95 |
Thanks for improving the performance of llama.cpp. It seems that you were correct: lookup decoding improves speed but adds a constant overhead, so larger models benefit more from it. What does performance look like for 7-13b models on slower GPUs and CPU-only backends?
I think the model and prompt will be a bigger factor than the hardware as long as the hashing is fast enough. These are some numbers I get on my Epyc 7742 CPU with 8x 3200 MHz Micron DIMMs:
Model | Static lookup cache | Slots | Iterations master | Iterations PR --draft 5 | Speedup vs. master |
---|---|---|---|---|---|
Phi-2 3b q4_0 | None | 1 | 103 | 119 | 1.16 |
LLaMA 3 70b q4_K_M | None | 1 | 3 | 5 | 1.67 |
Note that the comparatively large speedups with LLaMA 3 70b are likely a product of heavy repetition since I am using the base model.
I've added a test asserting that lookup decoding produces correct results. At temperature 0 the generated sequences are the same, though the results are not bit-for-bit identical. I've also investigated the performance for LLaMA 3 Instruct in more detail:
Results
Model | GPU | Static lookup cache | Slots | Seed | Previous runs | Iterations master | Iterations PR --draft 5 | Speedup vs. master |
---|---|---|---|---|---|---|---|---|
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 1 | 42 | 0 | 166 | 182 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 2 | 42 | 0 | 278 | 283 | 1.02 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 4 | 42 | 0 | 429 | 367 | 0.86 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 8 | 42 | 0 | 531 | 407 | 0.77 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | 42 | 0 | 166 | 183 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | 42 | 0 | 278 | 282 | 1.01 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | 42 | 0 | 429 | 355 | 0.83 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 0 | 166 | 182 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 1 | 166 | 186 | 1.12 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 2 | 166 | 204 | 1.23 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 3 | 166 | 208 | 1.25 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 4 | 166 | 215 | 1.30 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 5 | 166 | 219 | 1.32 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 6 | 166 | 222 | 1.34 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 7 | 166 | 222 | 1.34 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 0 | 278 | 285 | 1.03 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 1 | 278 | 283 | 1.02 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 2 | 278 | 309 | 1.11 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 3 | 278 | 315 | 1.13 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 4 | 278 | 324 | 1.17 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 5 | 278 | 326 | 1.17 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 6 | 278 | 329 | 1.18 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 7 | 278 | 333 | 1.20 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 0 | 429 | 363 | 0.85 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 1 | 429 | 353 | 0.82 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 2 | 429 | 370 | 0.86 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 3 | 429 | 378 | 0.88 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 4 | 429 | 378 | 0.88 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 5 | 429 | 383 | 0.89 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 6 | 429 | 383 | 0.89 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 7 | 429 | 388 | 0.90 |
LLaMA 3 Instruct 70b q4_K_M | 2x RTX 4090 | None | 1 | 42 | 0 | 28 | 31 | 1.11 |
LLaMA 3 Instruct 70b q4_K_M | 2x RTX 4090 | None | 2 | 42 | 0 | 55 | 57 | 1.04 |
LLaMA 3 Instruct 70b q4_K_M | 2x RTX 4090 | None | 4 | 42 | 0 | 96 | 69 | 0.72 |
LLaMA 3 Instruct 70b q4_K_M | 2x RTX 4090 | None | 8 | 42 | 0 | 120 | OOM | ??? |
The speedups for LLaMA 3 Instruct 8b and 70b seem to be very similar. The current implementation is only faster for small numbers of slots, since there is comparatively less benefit from adding more tokens to the batch if you're already at 8 tokens per batch without any speculative decoding. Successive, similar runs with different seeds but a carried-over dynamic cache result in increasing performance over time; for a single slot the 8th run was ~1.2x faster than the first one.
From my side I would consider this PR ready to be merged once one last issue is resolved: whether n-gram lookup should be enabled or disabled by default. The default number of slots is 1, and for that case it is faster. However, due to the varying batch size it also causes nondeterministic results. I would personally lean towards having n-gram lookup disabled by default, but I do not have a strong opinion on it.
@JohannesGaessler can I convince you to quickly add an overload for `std::hash<llama_token>` and do a quick comparison? While the shift in the n-gram hash shuffles the hash a bit, it is probably still pretty bad. Plus, this is a very small change.
I'm not sure what you mean by overload but I'm happy to test suggested alternatives.
Try the following:
```diff
diff --git a/common/ngram-cache.h b/common/ngram-cache.h
index 6575ea05..df420e1f 100644
--- a/common/ngram-cache.h
+++ b/common/ngram-cache.h
@@ -37,13 +37,18 @@ struct llama_ngram {
         }
     };
 };
+struct llama_token_hash_function {
+    size_t operator()(const llama_token token) const {
+        return token * 11400714819323198485llu;
+    }
+};
+
 struct llama_ngram_hash_function {
     size_t operator()(const llama_ngram & ngram) const {
-        size_t hash = ngram.tokens[0];
+        size_t hash = llama_token_hash_function{}(ngram.tokens[0]);
         for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
-            hash <<= 15;
-            hash ^= ngram.tokens[i];
+            hash ^= llama_token_hash_function{}(ngram.tokens[i]);
         }
         return hash;
@@ -51,7 +56,7 @@ struct llama_ngram_hash_function {
 };
```
Instead of an overload, I went the same route as you and used another callable type. Notes:
- I removed the shift, since it discards a lot of information in this case.
- Since the n-gram hash always runs over all `LLAMA_NGRAM_MAX` tokens (`#define LLAMA_NGRAM_MAX 4`) and the unused ones are `-1`, the shifts actually push any entropy away, which collapses the n-grams to the same hash again.
- The multiply is probably enough.
Please test :)
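One further property of the combination step in the diff above may be worth checking: XOR is commutative and self-inverse, so `[a, b, -1, -1]` and `[b, a, -1, -1]` get the same hash, and every n-gram of the form `[a, a, -1, -1]` hashes to 0. Since `std::unordered_map` still compares keys with `operator==`, this costs only lookup speed, not correctness. The sketch below reproduces the per-token hash from the diff and compares it with a hypothetical order-sensitive variant that multiplies the running hash before adding each token, in the spirit of the earlier remark about combining hashes with another multiplication; it has not been benchmarked.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

using token = int32_t;
constexpr int NGRAM_MAX = 4;                      // LLAMA_NGRAM_MAX in the thread
constexpr uint64_t FIB_MULT = 11400714819323198485ull;
using ngram = std::array<token, NGRAM_MAX>;       // unused slots are -1

// Per-token Fibonacci hash; like the diff, -1 converts to 2^64 - 1 before multiplying.
uint64_t token_hash(token t) {
    return static_cast<uint64_t>(t) * FIB_MULT;
}

// Combination used in the diff: XOR of the per-token hashes, so token order
// is lost and pairs of equal tokens cancel out.
uint64_t ngram_hash_xor(const ngram & ng) {
    uint64_t h = token_hash(ng[0]);
    for (size_t i = 1; i < ng.size(); ++i) {
        h ^= token_hash(ng[i]);
    }
    return h;
}

// Hypothetical order-sensitive alternative: multiply the running hash
// before adding the next token's hash.
uint64_t ngram_hash_ordered(const ngram & ng) {
    uint64_t h = token_hash(ng[0]);
    for (size_t i = 1; i < ng.size(); ++i) {
        h = h * FIB_MULT + token_hash(ng[i]);
    }
    return h;
}
```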
I adopted the Fibonacci hash implementation. For LLaMA 3 q4_K_M on an RTX 4090 it's maybe a ~1% end-to-end speedup.
Results
Model | GPU | Static lookup cache | Slots | Seed | Previous runs | Iterations master | Iterations PR --draft 5 | Speedup vs. master |
---|---|---|---|---|---|---|---|---|
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 1 | 42 | 0 | 166 | 183 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 2 | 42 | 0 | 278 | 285 | 1.03 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 4 | 42 | 0 | 429 | 365 | 0.85 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 8 | 42 | 0 | 531 | 417 | 0.79 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | 42 | 0 | 166 | 183 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | 42 | 0 | 278 | 284 | 1.02 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | 42 | 0 | 429 | 360 | 0.84 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 0 | 166 | 183 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 1 | 166 | 184 | 1.11 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 2 | 166 | 206 | 1.24 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 3 | 166 | 212 | 1.28 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 4 | 166 | 215 | 1.30 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 5 | 166 | 219 | 1.32 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 6 | 166 | 221 | 1.33 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 1 | -1 | 7 | 166 | 223 | 1.34 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 0 | 278 | 288 | 1.04 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 1 | 278 | 283 | 1.02 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 2 | 278 | 308 | 1.11 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 3 | 278 | 315 | 1.13 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 4 | 278 | 322 | 1.16 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 5 | 278 | 322 | 1.16 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 6 | 278 | 329 | 1.18 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 2 | -1 | 7 | 278 | 330 | 1.19 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 0 | 429 | 358 | 0.83 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 1 | 429 | 353 | 0.82 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 2 | 429 | 372 | 0.87 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 3 | 429 | 377 | 0.88 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 4 | 429 | 380 | 0.89 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 5 | 429 | 383 | 0.89 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 6 | 429 | 386 | 0.90 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | WT 103 | 4 | -1 | 7 | 429 | 389 | 0.91 |
Different STL implementations will perform differently here.
I re-tested the performance on 1x RTX 4090 with CUDA graphs, but contrary to my expectations I am seeing virtually no performance difference compared to before:
Model | GPU | Static lookup cache | Slots | Seed | Previous runs | Iterations master | Iterations PR --draft 5 | Speedup vs. master |
---|---|---|---|---|---|---|---|---|
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 1 | 42 | 0 | 167 | 183 | 1.10 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 2 | 42 | 0 | 277 | 284 | 1.03 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 4 | 42 | 0 | 426 | 363 | 0.85 |
LLaMA 3 Instruct 8b q4_K_M | 1x RTX 4090 | None | 8 | 42 | 0 | 540 | 417 | 0.77 |
A quick question: how does the number of drafted tokens affect performance? I saw you have many branches with different draft values.
The numbers in the names of the `server-ngram` branches on my repository are just what I use internally to keep my branches apart. Just use the branch I'm using for this PR.