llama-cpp-python
Add beam search
Invoke by adding "beam_width": 2 (for example) to the /v1/completions POST body.
This PR will be moved out of Draft mode after https://github.com/ggerganov/llama.cpp/pull/2267 is merged. Closes #145 Closes #340 Closes #185
In the meantime, there is a question: how does one specify --logits_all False when invoking the server from the command line?
python3 -m llama_cpp.server --logits_all False
results in settings.logits_all=True on startup.
You can pass False with an empty string: --logits_all ''
But a better way would be to add either action='store_true' or action=argparse.BooleanOptionalAction when adding the argument.
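For reference, here is a minimal standalone sketch of how argparse.BooleanOptionalAction behaves; this is not the actual llama_cpp.server argument setup, just an illustration of the suggestion:

```python
import argparse

# Standalone sketch, not the real llama_cpp.server CLI code.
# BooleanOptionalAction (Python 3.9+) registers both --logits_all and
# --no-logits_all, so passing an explicit False no longer needs an empty string.
parser = argparse.ArgumentParser()
parser.add_argument("--logits_all", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).logits_all)                   # True (default)
print(parser.parse_args(["--logits_all"]).logits_all)     # True
print(parser.parse_args(["--no-logits_all"]).logits_all)  # False
```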
@mattpulver Guess now you can un-draft this PR.
> You can pass False with an empty string: --logits_all ''
> But a better way would be to add either action='store_true' or action=argparse.BooleanOptionalAction when adding the argument.
When I add --logits_all '' to python -m llama_cpp.server it errors with:
__main__.py: error: argument --logits_all: invalid parse_bool_arg value: ''
In the meantime I changed the default setting for logits_all
from True to False: 3fce944
but I welcome any better suggestions. (I'm not sure where exactly to make the action='store_true'
suggestion.)
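For completeness, anyone driving the library directly (rather than through the server CLI) can sidestep the flag-parsing question by passing logits_all explicitly to the Llama constructor; the model path below is a placeholder:

```python
from llama_cpp import Llama

# Pass logits_all explicitly instead of relying on the default, which this
# discussion shows differs between the server settings and llama.cpp itself.
llm = Llama(
    model_path="./models/7B/model.gguf",  # placeholder path
    logits_all=False,
)
```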
Testing
- Start the web server.
- Go to http://localhost:8000/docs (adjust port as needed).
- Open/click the POST /v1/completions panel.
- Press the Try it out button.
- Edit the Example Value JSON by adding "beam_width": 2,
- Press Execute.
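Equivalently, the same request can be scripted instead of going through the Swagger UI; this sketch assumes the server is running on localhost:8000, and the prompt and max_tokens values are placeholders:

```python
import requests

# Same call as the Swagger UI steps above, sent directly to the server.
# beam_width is the new field added by this PR.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The quick brown fox",
        "max_tokens": 32,
        "beam_width": 2,
    },
)
print(response.json()["choices"][0]["text"])
```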
If you would like to see how the beams evolve and their probabilities, uncomment:
#print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
#for i in range(beams_state.n_beams):
#    print(f"beams[{i}]", beam_view_to_string(callback_data.ctx, beams_state.beam_views[i]))
in llama_cpp/llama.py.
Resolves #184
@mattpulver great work here, I'll review this and should have it merged this week.
Cheers
@abetlen Thanks. Perhaps the most intrusive integration change is changing the default value of the logits_all
command line parameter from True to False: 3fce944a9acb574407ee9ce6998530ea6f072915
This actually matches the default value in llama.cpp, so in a broader sense it makes sense IMO, but it may break some existing functionality.
@mattpulver just a quick update, I'm going to hold off on merging this until after #771 because that's going to have a big impact on how we use the llama.cpp API internally in the Llama
class. Once that's in I'll take a look at dealing with the merge conflicts here. One thing to note: I won't change the default behaviour of logits_all, as this would constitute a breaking change for a number of users, so we need to find a better solution (automatically reload the model with logits_all=False) or inform the user that it has to be set to False for beam search.
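As a rough sketch of the "inform the user" option (the function name, attribute access, and error text are assumptions, not code from this PR), the server could refuse beam search when the model was loaded with logits_all=True:

```python
from typing import Optional

# Illustrative sketch only; the attribute access and message are assumptions.
def check_beam_search_settings(llama, beam_width: Optional[int]) -> None:
    if beam_width is not None and getattr(llama, "logits_all", False):
        raise ValueError(
            "beam_width requires the model to be loaded with logits_all=False; "
            "reload the model with logits_all=False before using beam search."
        )
```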
If I wait a little longer, I consistently hit GGML_ASSERT: llama-cpp-python/vendor/llama.cpp/llama.cpp:5967: n_tokens <= n_batch
with this applied to the latest llama-cpp-python.
beam.tokens.size()
ends up at 513 here, which is one more than the batch size: https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L7849-L7852
// beam is not at end-of-sentence, so branch with next top_k tokens.
if (!beam.tokens.empty()) {
    llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0));
}
backtrace (line numbers may not be accurate):
#4 0x00007f59ef28c43b in llama_decode_internal (lctx=..., batch=...) at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:5970
#5 0x00007f59ef28c906 in llama_decode (ctx=<optimized out>, batch=<error reading variable: Cannot access memory at address 0x8>)
at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:9781
#6 0x00007f59ef2a45df in llama_beam_search_data::fill_next_beams_by_top_probabilities (this=this@entry=0x7f5a0e1faca0, beam=...)
at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:7951
#7 0x00007f59ef2a530f in llama_beam_search_data::loop (this=this@entry=0x7f5a0e1faca0, callback=callback@entry=0x7f5b2dfa3010,
callback_data=callback_data@entry=0x7f5a15c768d0) at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:8032
#8 0x00007f59ef28e45c in llama_beam_search (ctx=0x7f59e9e34e30, callback=0x7f5b2dfa3010, callback_data=0x7f5a15c768d0, n_beams=<optimized out>,
n_past=<optimized out>, n_predict=<optimized out>) at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:8072
I can reproduce it within 30 seconds or so with a 7B model on CUDA on commit 1a1c3dc418a3e55073676ea60b3c5b57117d3421. On earlier commits, it seems to just hang.
Fixing the EOS issue on my end does not resolve this.
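Until the assert is fixed upstream, one untested workaround on the Python side might be to give the context a batch size at least as large as the longest beam the search can produce; this is purely an assumption based on the n_tokens <= n_batch condition, not a confirmed fix, and the path and sizes below are placeholders:

```python
from llama_cpp import Llama

# Untested mitigation sketch: the assert fires when a beam grows past n_batch,
# so a larger n_batch raises that ceiling.
llm = Llama(
    model_path="./models/7B/model.gguf",  # placeholder path
    n_ctx=2048,
    n_batch=2048,  # keep n_batch >= the longest beam sequence expected
    logits_all=False,
)
```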
Thanks @mattpulver for the PR (both here and in the llama.cpp repo)
Curious if abetlen, cebtenzzre, or someone else knows if/when this will get merged? I've found beam search extremely helpful for code generation, and would love to know if it'll be supported with the main llama-cpp-python
library in the near future, or if I should create a binary from Matt's PR instead.
Hey @rishsriv I'm still planning to merge this, however I'm currently grinding through the batch processing support first as it requires a bunch of internal refactoring; after that I was planning on coming back and merging this. Can't give an ETA though, likely in the next few weeks if I had to guess.
Got it – thank you for the response! If there are things you think external contributors will be able to fix, please do open an issue and I'll be happy to help fix them.
> If I wait a little longer, I consistently hit GGML_ASSERT: llama-cpp-python/vendor/llama.cpp/llama.cpp:5967: n_tokens <= n_batch with this applied to the latest llama-cpp-python.
Possibly related: ggerganov/llama.cpp#6664
Any update on this?
Alright, so I copied the changes into a local clone and it seems to be working, or running at least. First it spends a very long time in llama_cpp.llama_beam_search, then the token output rate is quite low, which makes the sampling time insanely high.
beam_width=1
HEAD of master:
llama_print_timings: load time = 154.96 ms
llama_print_timings: sample time = 1534.47 ms / 548 runs ( 2.80 ms per token, 357.13 tokens per second)
llama_print_timings: prompt eval time = 371.38 ms / 1661 tokens ( 0.22 ms per token, 4472.54 tokens per second)
llama_print_timings: eval time = 4816.73 ms / 547 runs ( 8.81 ms per token, 113.56 tokens per second)
This fork:
llama_print_timings: load time = 127.43 ms
llama_print_timings: sample time = 59429.62 ms / 1 runs (59429.62 ms per token, 0.02 tokens per second)
llama_print_timings: prompt eval time = 338.67 ms / 1661 tokens ( 0.20 ms per token, 4904.55 tokens per second)
llama_print_timings: eval time = 58430.05 ms / 6530 runs ( 8.95 ms per token, 111.76 tokens per second)
59429.62 ms per token