
llama : add grammar-based sampling

ejones opened this issue 1 year ago • 19 comments

Inspired by #1397 and grantslatton's CFG work, this adds an API that takes a serialized context-free grammar to guide and constrain sampling. Also adds a sample Backus-Naur form (BNF)-like syntax in main for specifying a grammar for generations.

Testing

(M2 Max, 30B)

Chess
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A good game:\n\n' --grammar-file grammars/chess.gbnf
main: build = 645 (fd0eb66)
main: seed  = 1686285871
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


main: grammar:
<0>root_3 ::= <2>[0-9] | 
<9>root_2 ::= <11>[1-9] <14>root_3 <16>[.-.] <19>[ - ] <22>move <24>[ - ] <27>move <29>[\n-\n] 
<34>root_4 ::= <36>root_2 <38>root_4 | <42>root_2 
<46>root ::= <48>[1-1] <51>[.-.] <54>[ - ] <57>move <59>[ - ] <62>move <64>[\n-\n] <67>root_4 
<71>move_5 ::= <73>pawn | <77>nonpawn | <81>castle 
<85>move_9 ::= <87>[+-+#-#] | 
<96>move ::= <98>move_5 <100>move_9 
<104>nonpawn_10 ::= <106>[a-h] | 
<113>nonpawn_11 ::= <115>[1-8] | 
<122>nonpawn_12 ::= <124>[x-x] | 
<131>nonpawn ::= <133>[N-NB-BK-KQ-QR-R] <144>nonpawn_10 <146>nonpawn_11 <148>nonpawn_12 <150>[a-h] <153>[1-8] 
<158>pawn_13 ::= <160>[a-h] <163>[x-x] 
<168>pawn_14 ::= <170>pawn_13 | 
<176>pawn_15 ::= <178>[=-=] <181>[N-NB-BK-KQ-QR-R] 
<194>pawn_16 ::= <196>pawn_15 | 
<202>pawn ::= <204>pawn_14 <206>[a-h] <209>[1-8] <212>pawn_16 
<216>castle_17 ::= <218>[---] <221>[O-O] | 
<228>castle ::= <230>[O-O] <233>[---] <236>[O-O] <239>castle_17 

 A good game:

1. e4 c5
2. Nf3 d6
3. d4 cxd4
4. Nxd4 Nf6
llama_print_timings:        load time =  1231.29 ms
llama_print_timings:      sample time =    32.94 ms /    32 runs   (    1.03 ms per token)
llama_print_timings: prompt eval time =  1214.28 ms /     7 tokens (  173.47 ms per token)
llama_print_timings:        eval time =  5247.10 ms /    31 runs   (  169.26 ms per token)
llama_print_timings:       total time =  6514.52 ms
"Chess" without grammar
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A good game:\n\n'  

main: build = 645 (fd0eb66)
main: seed  = 1686286016
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 A good game:

Sir Thomas Gresham, when he was building his famous Exchange at London, had the following dialogue with a mason, whose name was Richard B
llama_print_timings:        load time =  1185.47 ms
llama_print_timings:      sample time =    21.57 ms /    32 runs   (    0.67 ms per token)
llama_print_timings: prompt eval time =  1167.67 ms /     7 tokens (  166.81 ms per token)
llama_print_timings:        eval time =  4977.97 ms /    31 runs   (  160.58 ms per token)
llama_print_timings:       total time =  6188.21 ms
Arithmetic
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Some arithmetic practice:\n\n' \
--grammar 'root  ::= (expr "=" ws num "\n")+
expr  ::= term ([-+*/] term)*
term  ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num   ::= [0-9]+ ws
ws    ::= [ \t\n]*'
main: build = 645 (fd0eb66)
main: seed  = 1686286304
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


main: grammar:
<0>root_1 ::= <2>expr <4>[=-=] <7>ws <9>num <11>[\n-\n] 
<16>root_5 ::= <18>root_1 <20>root_5 | <24>root_1 
<28>root ::= <30>root_5 
<34>expr_7 ::= <36>[---+-+*-*/-/] <45>term 
<49>expr_8 ::= <51>expr_7 <53>expr_8 | 
<59>expr ::= <61>term <63>expr_8 
<67>term ::= <69>ident | <73>num | <77>[(-(] <80>ws <82>expr <84>[)-)] <87>ws 
<91>ident_10 ::= <93>[a-z0-9_-_] <100>ident_10 | 
<106>ident ::= <108>[a-z] <111>ident_10 <113>ws 
<117>num_11 ::= <119>[0-9] <122>num_11 | <126>[0-9] 
<131>num ::= <133>num_11 <135>ws 
<139>ws_12 ::= <141>[ - \t-\t\n-\n] <148>ws_12 | 
<154>ws ::= <156>ws_12 

 Some arithmetic practice:

10/2 =
5

9/2 =
4

6/2 =
3

8/2 =
4

llama_print_timings:        load time =  1185.46 ms
llama_print_timings:      sample time =    39.62 ms /    32 runs   (    1.24 ms per token)
llama_print_timings: prompt eval time =  1168.63 ms /     7 tokens (  166.95 ms per token)
llama_print_timings:        eval time =  5056.22 ms /    31 runs   (  163.10 ms per token)
llama_print_timings:       total time =  6284.71 ms
Arithmetic - no grammar
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Some arithmetic practice:\n\n'                                            
main: build = 645 (fd0eb66)
main: seed  = 1686286388
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 Some arithmetic practice:

\begin{code}
package main

import (
    "fmt"
)

func main() {
    fmt.Println(
llama_print_timings:        load time =  1171.65 ms
llama_print_timings:      sample time =    21.37 ms /    32 runs   (    0.67 ms per token)
llama_print_timings: prompt eval time =  1153.88 ms /     7 tokens (  164.84 ms per token)
llama_print_timings:        eval time =  4991.68 ms /    31 runs   (  161.02 ms per token)
llama_print_timings:       total time =  6187.91 ms
JSON
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A bit about me:\n\n' --grammar-file grammars/json.gbnf
main: build = 645 (fd0eb66)
main: seed  = 1686286524
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


main: grammar:
<0>root ::= <2>object | <6>array 
<10>value ::= <12>object | <16>array | <20>string | <24>number | <28>boolean 
<32>object_9 ::= <34>[,-,] <37>ws <39>string <41>[:-:] <44>ws <46>value 
<50>object_10 ::= <52>object_9 <54>object_10 | 
<60>object_8 ::= <62>string <64>[:-:] <67>ws <69>value <71>object_10 
<75>object_11 ::= <77>object_8 | 
<83>object ::= <85>[{-{] <88>ws <90>object_11 <92>[}-}] 
<97>array_13 ::= <99>[,-,] <102>ws <104>value 
<108>array_14 ::= <110>array_13 <112>array_14 | 
<118>array_12 ::= <120>value <122>array_14 
<126>array_15 ::= <128>array_12 | 
<134>array ::= <136>[[-[] <139>ws <141>array_15 <143>[]-]] 
<148>string_16 ::= <150>[ - 	-	!-!#-[]-~] <161>string_16 | 
<167>string ::= <169>["-"] <172>string_16 <174>["-"] <177>ws 
<181>number_17 ::= <183>[0-9] <186>number_17 | <190>[0-9] 
<195>number ::= <197>number_17 <199>ws 
<203>boolean_18 ::= <205>[t-t] <208>[r-r] <211>[u-u] <214>[e-e] | <219>[f-f] <222>[a-a] <225>[l-l] <228>[s-s] <231>[e-e] 
<236>boolean ::= <238>boolean_18 <240>ws 
<244>ws ::= <246>[ - 	-		-	] <253>ws | 

 A bit about me:

{   "name": "Prakash",    "age":30,      "married": true} [end of text]

llama_print_timings:        load time =  1302.79 ms
llama_print_timings:      sample time =    27.64 ms /    25 runs   (    1.11 ms per token)
llama_print_timings: prompt eval time =  1284.96 ms /     8 tokens (  160.62 ms per token)
llama_print_timings:        eval time =  3959.47 ms /    24 runs   (  164.98 ms per token)
llama_print_timings:       total time =  5292.54 ms
"JSON" - no grammar
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A bit about me:\n\n'                                                                          
main: build = 645 (fd0eb66)
main: seed  = 1686286615
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 A bit about me:

A former teacher, now a full-time writer. I am the author of two novels: _The Man in the Moon_ and _The Riddle
llama_print_timings:        load time =  1291.32 ms
llama_print_timings:      sample time =    21.48 ms /    32 runs   (    0.67 ms per token)
llama_print_timings: prompt eval time =  1274.63 ms /     8 tokens (  159.33 ms per token)
llama_print_timings:        eval time =  4990.01 ms /    31 runs   (  160.97 ms per token)
llama_print_timings:       total time =  6306.01 ms

Approach

Grammar API

The llama API accepts a 16-bit binary encoding of a context-free grammar. My hope was that a serialized format would improve cache locality while simplifying the data structures and C API:

// The binary format represents one or more production rules, each with one or more alternate
// definitions:
//
// (<rule_id: u16> (<alt_size: u16> <alt_size * u16>)+ 0000)+ FFFF
//
// rule_ids should be assigned sequentially from zero but may appear out of order. Each
// rule alternate is a sequence of zero or more symbols, each prefixed with size:
//
// (<sym_size: u16> <sym_size * u16>)* 0000
//
// A symbol of size 1 is interpreted as a rule reference (whose value is the single following
// u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to
// match. Note that symbol sizes greater than 7FFF are reserved for future use.

The reserved symbol sizes (8000-FFFF, i.e. anything greater than 7FFF) could perhaps eventually be used to encode biases or callbacks.
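As a concrete reading of the layout above, here is a minimal C++ sketch that encodes a two-rule grammar into the u16 stream and prints it back. The types and helpers (Symbol, Rule, encode_grammar, dump_grammar) are hypothetical, not the PR's API, and the sketch assumes that alt_size counts every u16 of an alternate, including its terminating 0000; the actual encoder may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical in-memory types, used only to drive the sketch.
struct Symbol {
    bool is_ref;                                        // true: reference to another rule
    uint16_t ref;                                       // rule id when is_ref
    std::vector<std::pair<uint16_t, uint16_t>> ranges;  // inclusive char ranges otherwise
};
using Alternate = std::vector<Symbol>;
struct Rule {
    uint16_t id;
    std::vector<Alternate> alts;
};

Symbol sym_ref(uint16_t id) { return Symbol{true, id, {}}; }
Symbol sym_range(uint16_t lo, uint16_t hi) { return Symbol{false, 0, {{lo, hi}}}; }

// root ::= [a-z] digit | digit   (rule 0)
// digit ::= [0-9]                (rule 1)
std::vector<Rule> example_grammar() {
    return {
        {0, {{sym_range('a', 'z'), sym_ref(1)}, {sym_ref(1)}}},
        {1, {{sym_range('0', '9')}}},
    };
}

std::vector<uint16_t> encode_grammar(const std::vector<Rule>& rules) {
    std::vector<uint16_t> out;
    for (const Rule& rule : rules) {
        out.push_back(rule.id);
        for (const Alternate& alt : rule.alts) {
            std::vector<uint16_t> body;
            for (const Symbol& sym : alt) {
                if (sym.is_ref) {
                    body.push_back(1);                   // size 1: rule reference
                    body.push_back(sym.ref);
                } else {
                    assert(!sym.ranges.empty() && 2 * sym.ranges.size() < 0x8000);
                    body.push_back(static_cast<uint16_t>(2 * sym.ranges.size()));
                    for (const auto& r : sym.ranges) {   // inclusive (lo, hi) pairs
                        body.push_back(r.first);
                        body.push_back(r.second);
                    }
                }
            }
            body.push_back(0x0000);                      // end of alternate
            out.push_back(static_cast<uint16_t>(body.size()));  // alt_size (assumed to include the 0000)
            out.insert(out.end(), body.begin(), body.end());
        }
        out.push_back(0x0000);                           // alt_size 0: end of rule
    }
    out.push_back(0xFFFF);                               // end of grammar
    return out;
}

// Print each rule as "id ::= alt | alt", with rule references as @id and char
// ranges as [lo-hi...]. Chars are printed as 8-bit values (ASCII examples only).
std::string dump_grammar(const std::vector<uint16_t>& g) {
    std::string out;
    size_t pos = 0;
    while (g.at(pos) != 0xFFFF) {
        out += std::to_string(g[pos++]) + " ::=";
        bool first_alt = true;
        while (g.at(pos) != 0x0000) {                    // alt_size 0 ends the rule
            size_t alt_end = pos + 1 + g[pos];           // one past this alternate
            ++pos;
            if (!first_alt) out += " |";
            first_alt = false;
            while (g.at(pos) != 0x0000) {                // sym_size 0 ends the alternate
                uint16_t size = g[pos++];
                assert(size < 0x8000);                   // larger sizes are reserved
                if (size == 1) {
                    out += " @" + std::to_string(g[pos++]);
                } else {
                    assert(size % 2 == 0);               // pairs of inclusive bounds
                    out += " [";
                    for (uint16_t i = 0; i < size; i += 2) {
                        out += static_cast<char>(g[pos]);
                        out += '-';
                        out += static_cast<char>(g[pos + 1]);
                        pos += 2;
                    }
                    out += "]";
                }
            }
            ++pos;                                       // skip the alternate's 0000
            assert(pos == alt_end);
        }
        ++pos;                                           // skip the rule's 0000
        out += "\n";
    }
    return out;
}
```

Under that assumption the example grammar encodes to 21 u16 values, and dump_grammar returns "0 ::= [a-z] @1 | @1\n1 ::= [0-9]\n".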

Sampling

The grammar sampling code models a nondeterministic pushdown automaton, maintaining N stacks for the possible parse states. Sampling a token is done in two steps: a sampling API (llama_sample_grammar) and a post-sample transformation (llama_grammar_accept_token). The former filters candidate tokens based only on the next character expected by each parse stack, which should be fast. The latter then truncates the sampled token, if needed, to one that satisfies the grammar, while updating the parse state.

As a special case, the sampling code advances the grammar by a single space and adds any valid tokens at that state as well, since many tokens (about a third?) start with a single space.
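A rough C++ illustration of the first step only: keep a candidate token if its first character is accepted by the top of at least one parse stack. This is a simplified sketch under stated assumptions: the types and names are hypothetical, each stack top is assumed to already be a character class (rule references expanded), and the leading-space special case is omitted.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified types: a terminal is a set of inclusive char ranges,
// and each parse stack lists the terminals still expected, next one at the back.
using CharRanges = std::vector<std::pair<char, char>>;
using ParseStack = std::vector<CharRanges>;

bool matches(const CharRanges& ranges, char c) {
    for (const auto& r : ranges)
        if (r.first <= c && c <= r.second) return true;
    return false;
}

// Keep a candidate if its first char is accepted by the top of at least one stack.
// An empty stack (a completed parse) accepts nothing here.
std::vector<std::string> filter_candidates(const std::vector<std::string>& candidates,
                                           const std::vector<ParseStack>& stacks) {
    std::vector<std::string> kept;
    for (const std::string& tok : candidates) {
        if (tok.empty()) continue;
        for (const ParseStack& stack : stacks) {
            if (!stack.empty() && matches(stack.back(), tok[0])) {
                kept.push_back(tok);
                break;  // one accepting stack is enough
            }
        }
    }
    return kept;
}

// Stacks at the start of a chess move: pawn [a-h], piece [NBKQR], or castle [O].
std::vector<ParseStack> chess_move_stacks() {
    return {
        ParseStack{CharRanges{{'a', 'h'}}},
        ParseStack{CharRanges{{'N', 'N'}, {'B', 'B'}, {'K', 'K'}, {'Q', 'Q'}, {'R', 'R'}}},
        ParseStack{CharRanges{{'O', 'O'}}},
    };
}
```

With candidates "e4", "Nf3", "O-O", "hello", " e4" and "1.", the first four survive. "hello" passes because 'h' is in [a-h]; a first-character check cannot reject it, which is the gap the post-sample step (llama_grammar_accept_token) is described as closing.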

Examples

Adds a --grammar argument to main taking a simple extended BNF to constrain generations. The parser for this format is implemented in examples/grammar-parser.{h,cpp}:

// ... Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root  ::= expr
// expr  ::= term ([-+*/] term)*
// term  ::= num | "(" space expr ")" space
// num   ::= [0-9]+ space
// space ::= [ \t\n]*

The root rule identifies the start of the grammar.
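The generated rules printed in the logs above suggest how repetition is lowered: `x*` becomes a right-recursive helper rule with an empty alternate (e.g. `ident_10 ::= [a-z0-9_] ident_10 |`), `x+` a helper whose last alternate is a single `x` (e.g. `num_11 ::= [0-9] num_11 | [0-9]`), and an optional group appears to become `x |` (e.g. `pawn_14 ::= pawn_13 |`). The C++ sketch below reproduces those shapes as text; `desugar_repetition` is a hypothetical name, and this is not the code in examples/grammar-parser.cpp.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: lower `item*`, `item+`, or `item?` into a helper rule named
// `name`, using the right-recursive shapes seen in the generated grammars above.
// Returns the helper rule as text, or "" for an unknown operator.
std::string desugar_repetition(const std::string& name, const std::string& item, char op) {
    switch (op) {
        case '*': return name + " ::= " + item + " " + name + " | ";          // zero or more
        case '+': return name + " ::= " + item + " " + name + " | " + item;   // one or more
        case '?': return name + " ::= " + item + " | ";                       // zero or one
        default:  return "";                                                  // not a repetition operator
    }
}
```

The calling rule then refers to the helper by name, as in `num ::= num_11 ws`.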

Caveats

  • the binary format makes the code harder to understand and more brittle
  • the grammar format is designed for 16-bit chars, but it is currently only applied to the 8-bit UTF-8 chars in token strings
  • the 1-char lookahead sampling is probably biasing generations in a weird way; further investigation into output quality is probably needed

ejones, Jun 09 '23 05:06

I suggest taking a grammar file as a parameter and adding several example grammars, as we did for prompts (in the .\prompts folder).

howard0su, Jun 10 '23 01:06

Incredibly useful contribution. It's really amazing how much this simplifies many use cases.

I agree that it would be better if the grammar came from a file.

Two snags I hit while trying this out:

  • it crashes with --prompt-cache
  • any empty lines in the grammar cause a crash

Some additional thoughts:

  • Would love to have the grammars support empty lines and comments
  • I wonder if the grammar could be compiled into a tensor of state transitions and run on the GPU
  • I wonder if there is an optimization where, if the next token is already known from the grammar, we could skip the inference and just add it? In many types of grammars, like JSON or HTML, that could really speed up generation
  • I think it's worth allowing full tokens to be referenced from the grammar, maybe with something like @“ token” or @13432 (the id of a token).

tobi, Jun 10 '23 20:06

Very nice! I am wondering what the rationale is for not including the parser in the llama.cpp API. Without it, most downstream users will be forced to manually copy the parser into their code to support the feature, which is not great. Also, for usability, I think it would be a good idea to keep a copy of the binary grammar in llama_grammar, rather than asking users to keep the provided copy alive. The overhead would be minimal, and it would simplify the code of downstream users.

slaren, Jun 11 '23 12:06

Thanks all! Just added support for grammar files (with examples) and updated the grammar syntax to add shell-style comments and allow empty lines between rules, as well as newlines inside parenthesized groups.

it crashes with --prompt-cache

I wonder if that was #1699 ? If so, should be fixed now

I wonder if the grammar could be compiled into a tensor of state transitions and run on the GPU

Sounds cool; I don't know enough about GPU programming to comment on that myself. The grammar participates in the sampling layer, and I'm not sure whether that leverages the GPU currently.

I wonder if there is an optimization where the next token is already known from the grammar we could skip the inference and just add it?

This is definitely possible. That said, AFAIK the token would still need to be evaluated, and that seems to be the bottleneck. Maybe the optimization comes in being able to batch eval strings of such tokens?

I think it's worth allowing to reference full tokens from the grammar

Neat idea. Would that be more of an optimization, or a way to reference tokens that can't be expressed textually?

what is the rationale for not including the parser in the llama.cpp API.

Honestly, I was trying to reduce the changes to llama.cpp itself. Agree it would be more convenient in the API.

I think it would be a good idea to keep a copy of the binary grammar

Makes sense. I left that out of this round of changes - if it's desired to have the grammar parser in the llama API, this may naturally fit with that change.

ejones, Jun 12 '23 04:06

First, this is amazing work.

This makes me wonder whether the entire sampling API should be pulled into something like llama_samplers instead. External samplers can evolve independently of the core API.

The existing functions can be kept for compatibility. AFAIK, the only thing we need is to expose the RNG. And even then, whether it belongs inside a state/context is debatable. The context window is already managed by user code, so why not sampling?

This reminds me a lot of https://lmql.ai/. There is also https://github.com/1rgs/jsonformer, where the input is a JSON schema, which is not always easy to express in BNF.

AFAIK the token would still need to be evaluated

Would it though? We just immediately add it to the context. It is done manually in user code now.

Maybe the optimization comes in being able to batch eval strings of such tokens?

AFAIK, that's the case. The initial prompt and the user input are submitted in a large batch. The inference loop just feeds the single chosen token back until EOS.

The grammar participates in the sampling layer, and I'm not sure if that leverages the GPU currently.

The current sampling is CPU.

bullno1, Jun 12 '23 15:06

This makes me wonder whether the entire sampling API should be pulled into something like llama_samplers instead.

One of the discussion points for adding more generic LLM tooling back into ggml (the repo) was moving the sampler there, but AFAIK nothing has happened yet :)

Green-Sky, Jun 12 '23 16:06

There is also https://github.com/1rgs/jsonformer where the input is a json schema

Was planning to tackle this next. I've got it more or less working locally in a branch off of this, at least with the examples on jsonformer's README. It uses a Python script to generate a JSON BNF that conforms to the schema.

ejones, Jun 12 '23 19:06
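The thread does not show that script. Purely as an illustration of the idea, in C++ rather than Python, here is a hypothetical sketch for the simplest case: a JSON object whose properties are fixed and ordered, each mapped to a value rule assumed to exist elsewhere in the grammar (string, number and ws, as in the grammars/json.gbnf output above). Real schemas (optional properties, nesting, arrays, enums) need considerably more.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical: build one GBNF rule for a JSON object with fixed, ordered properties.
// Each property pairs a key with the name of a value rule defined elsewhere.
// Keys are emitted verbatim, so they must not contain quotes or backslashes.
std::string schema_object_rule(const std::string& name,
                               const std::vector<std::pair<std::string, std::string>>& props) {
    std::string rule = name + R"( ::= "{" ws)";
    for (size_t i = 0; i < props.size(); ++i) {
        if (i > 0) rule += R"( "," ws)";                        // separator between properties
        rule += R"( "\")" + props[i].first + R"(\"" ws ":" ws )" + props[i].second;
    }
    rule += R"( "}")";
    return rule;
}
```

For example, the properties name (string) and age (number) yield `root ::= "{" ws "\"name\"" ws ":" ws string "," ws "\"age\"" ws ":" ws number "}"`.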

Great stuff!

I'm still wrapping my head around this.

  • Yes, this can become part of a llama.cpp or ggml sampling API, but I guess for now we can keep it as an example, see what the pros and cons are, and learn how to use it most efficiently
  • What happens when the next N > 1 tokens are uniquely determined by the grammar? I guess we will sample them one by one, correct? What would it take to submit them to be processed as a batch? This would significantly speed up inference in such cases

ggerganov, Jun 15 '23 18:06

  • Yes, this can become part of a llama.cpp or ggml sampling API, but I guess for now we can keep it as an example, see what the pros and cons are, and learn how to use it most efficiently

To clarify, this PR adds the core sampling functionality in llama.cpp, leaving the grammar parser out in examples. Should that all be moved to examples or just left as is?

  • What happens when the next N > 1 tokens are uniquely determined by the grammar? I guess we will sample them one by one, correct? What would it take to submit them to be processed as a batch? This would significantly speed up inference in such cases

Yes, that's correct. I think that's doable, I can take a stab at that.

ejones, Jun 16 '23 04:06
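One way to frame the batching question: after filtering, if exactly one token is allowed, it can be appended without sampling, and consecutive forced tokens can be collected and handed to the model in one evaluation. The C++ sketch below shows only the collection loop; the callback type and the toy grammar are hypothetical stand-ins for the grammar filter.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical callback: given the tokens so far, return the token ids the grammar
// still allows next (for example, the candidates left after grammar filtering).
using AllowedNext = std::function<std::vector<int>(const std::vector<int>&)>;

// Collect up to max_n forced tokens, i.e. steps where exactly one token is allowed.
// A caller could then evaluate them as one batch instead of one step each.
std::vector<int> collect_forced(const AllowedNext& allowed_next, std::vector<int> seq,
                                std::size_t max_n) {
    std::vector<int> forced;
    while (forced.size() < max_n) {
        std::vector<int> allowed = allowed_next(seq);
        if (allowed.size() != 1) break;  // zero or several options: sample normally
        forced.push_back(allowed[0]);
        seq.push_back(allowed[0]);
    }
    return forced;
}

// Toy grammar: the first three tokens are fixed (ids 10, 11, 12), then two choices open up.
std::vector<int> toy_allowed(const std::vector<int>& seq) {
    if (seq.size() < 3) return {10 + static_cast<int>(seq.size())};
    return {20, 21};
}
```

One caveat: text fixed by the grammar can often be tokenized in more than one way, so a forced string does not always translate into a single allowed token at each step.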

the grammar contemplates 16-bit chars but it's just being applied to the 8-bit UTF-8 chars in token strings currently

~~I don't understand this part. So it is converting to UTF-16?~~

~~Another option would be to use token values but it will be more limiting.~~

EDIT: I read through the code.

The grammar doesn't care about the text encoding. It could work with any encoding, provided that the rules match the characters correctly.

The parser doesn't understand UTF-8 so it will create rules that don't match as the user expects.

For example, if I wanted to create a rule to match all Hiragana characters, I should be able to write:

[ぁ-ゖ]

However, the parser doesn't see it as two characters separated by -; instead, it sees:

[\xe3\x81\x81-\xe3\x82\x96]

But the correct rule should be something like this?

"\xe3" [\x81-\x82] [\x81-\x96]

SlyEcho, Jun 16 '23 07:06
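To check the byte sequences quoted above, here is a small C++ sketch that encodes code points as UTF-8 and prints them in the same \xNN notation. The function names are hypothetical, and only code points up to U+FFFF are handled.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Encode one code point as UTF-8. For brevity this sketch only handles the
// Basic Multilingual Plane (up to U+FFFF) and does not reject surrogates.
std::string utf8_encode(uint32_t cp) {
    std::string out;
    if (cp < 0x80) {
        out += static_cast<char>(cp);                          // 0xxxxxxx
    } else if (cp < 0x800) {
        out += static_cast<char>(0xC0 | (cp >> 6));            // 110xxxxx
        out += static_cast<char>(0x80 | (cp & 0x3F));          // 10xxxxxx
    } else {
        assert(cp <= 0xFFFF);
        out += static_cast<char>(0xE0 | (cp >> 12));           // 1110xxxx
        out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));   // 10xxxxxx
        out += static_cast<char>(0x80 | (cp & 0x3F));          // 10xxxxxx
    }
    return out;
}

// Render bytes as \xNN escapes, the notation used in the comment above.
std::string hex_escape(const std::string& bytes) {
    static const char digits[] = "0123456789abcdef";
    std::string out;
    for (unsigned char b : bytes) {
        out += "\\x";
        out += digits[b >> 4];
        out += digits[b & 0xF];
    }
    return out;
}
```

This confirms ぁ (U+3041) is E3 81 81 and ゖ (U+3096) is E3 82 96. The code points in between run through E3 81 BF (U+307F) and E3 82 80 (U+3080), so a byte-level rule for the whole range would need two branches, something like "\xe3" ("\x81" [\x81-\xbf] | "\x82" [\x80-\x96]); the single rule above would miss U+3057 through U+3080. Whether the grammar syntax accepts \x escapes in literals and classes is an assumption here.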

Just don't use repeat penalties, to get the best grammar output llama can produce.

ivanstepanovftw, Jun 16 '23 11:06

To clarify, this PR adds the core sampling functionality in llama.cpp, leaving the grammar parser out in examples. Should that all be moved to examples or just left as is?

It's fine the way it is

ggerganov, Jun 16 '23 15:06

FWIW I'm adapting this code into an analogous feature for models running on torch. In my implementation, I'm doing grammar enforcement logit masking on the GPU across the full token set before selecting candidates: https://github.com/Shopify/torch-grammar/blob/df23e354083c909c70120e256ed34036c93f6714/grammar_sampler.py#L232-L239. The same strategy would probably work here if anyone was super motivated to try it.

burke, Jun 16 '23 20:06
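For reference, a minimal CPU-side C++ sketch of full-vocabulary masking in the spirit of the linked approach: every token the grammar disallows gets a logit of negative infinity before softmax, so it ends up with zero probability. The function names are hypothetical, and the `allowed` mask (one flag per vocabulary entry) is assumed to come from the grammar; computing that mask efficiently is the hard part and is not shown.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Set the logit of every disallowed token to -inf so that softmax gives it zero probability.
void mask_logits(std::vector<float>& logits, const std::vector<bool>& allowed) {
    assert(logits.size() == allowed.size());
    for (std::size_t i = 0; i < logits.size(); ++i)
        if (!allowed[i]) logits[i] = -std::numeric_limits<float>::infinity();
}

// Numerically stable softmax; requires at least one finite logit.
std::vector<float> softmax(const std::vector<float>& logits) {
    float max_logit = -std::numeric_limits<float>::infinity();
    for (float l : logits) max_logit = std::max(max_logit, l);
    assert(std::isfinite(max_logit));                  // every token masked: nothing to sample
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);    // exp(-inf) == 0 for masked tokens
        sum += probs[i];
    }
    for (float& p : probs) p /= sum;
    return probs;
}

// Convenience wrapper: mask, then normalize.
std::vector<float> masked_probs(std::vector<float> logits, const std::vector<bool>& allowed) {
    mask_logits(logits, allowed);
    return softmax(logits);
}
```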

@SlyEcho

The parser doesn't understand UTF-8 so it will create rules that don't match as the user expects.

Yeah, my rough plan for Unicode support was to store UTF-16 code units in the binary grammar, re-encoding the token strings as UTF-16 when sampling. And maybe limiting the BNF grammar to ascii with Unicode escapes. But I punted on Unicode support for this initial version.

ejones, Jun 18 '23 02:06

So, this grammar feature works in two parts:

  1. Parse the grammar into a bytecode grammar program.
  2. During generation check the grammar using a state machine.

My question is about this intermediate bytecode (or wordcode?) representation. Is there some performance benefit for it? Is the text grammar very hard to parse so it needs to be in binary form?

Would it not be easier to parse into and use some kind of graph data structures using C/C++ structs?

SlyEcho, Jun 18 '23 14:06

Is there some performance benefit for it? Would it not be easier to parse into and use some kind of graph data structures using C/C++ structs?

Yes, I agree, it would be simpler and clearer to use an AST expressed in structs. I can make that adjustment.

I went with the binary format for a couple of reasons. First, I hoped it would have good cache locality, although I admit I haven't benchmarked that, and this could be achieved with structs as well. Second, I wanted to reduce the overall size of changes to llama.cpp itself and keep the C API simple; I wasn't sure what the tolerance for changes there would be. This included keeping the parser external, meaning the AST/parser output is an input to the C API.

ejones, Jun 18 '23 22:06
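To make the structs-versus-binary question concrete, here is one possible flat, struct-based layout in C++: each rule is a contiguous array of tagged elements, and a character class such as [a-z0-9_] spans several consecutive elements. The enum, struct and function names are hypothetical and chosen for illustration only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical flat representation: a rule is one array of tagged elements ending
// in END, with ALT separating alternates. A character class is a CHAR element,
// optionally followed by CHAR_RNG_UPPER (closing a range), with further members
// introduced by CHAR_ALT.
enum class ElemType { END, ALT, RULE_REF, CHAR, CHAR_RNG_UPPER, CHAR_ALT };

struct Elem {
    ElemType type;
    uint32_t value;  // code point, or rule index for RULE_REF
};

// pos must point at the CHAR element that starts a class; returns whether c is in it.
bool char_class_matches(const Elem* pos, uint32_t c) {
    assert(pos->type == ElemType::CHAR);
    bool found = false;
    do {
        if (pos[1].type == ElemType::CHAR_RNG_UPPER) {  // inclusive range pos->value..pos[1].value
            found = found || (pos->value <= c && c <= pos[1].value);
            pos += 2;
        } else {                                         // single character
            found = found || pos->value == c;
            pos += 1;
        }
    } while (pos->type == ElemType::CHAR_ALT);
    return found;
}

// ident_10 ::= [a-z0-9_] ident_10 |     (rule index 10 is illustrative)
const std::vector<Elem>& ident_rule() {
    static const std::vector<Elem> rule = {
        {ElemType::CHAR, 'a'}, {ElemType::CHAR_RNG_UPPER, 'z'},
        {ElemType::CHAR_ALT, '0'}, {ElemType::CHAR_RNG_UPPER, '9'},
        {ElemType::CHAR_ALT, '_'},
        {ElemType::RULE_REF, 10},
        {ElemType::ALT},  // the second alternate is empty
        {ElemType::END},
    };
    return rule;
}
```

Keeping each rule in one array arguably preserves most of the locality argument for the binary format, while code can switch on a named element type instead of decoding sizes.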

Update on this: working on refactoring this to store the parsed grammar as structs. Also trying to think through the Unicode handling a bit more.

ejones, Jun 21 '23 22:06

Thank you for this useful addition. I've attempted to transform the BNF grammar for SQL-2003 into a compatible format and read it in with this branch. After llama_grammar_advance_stack() recurses to a depth of around 58k+ calls, I'm hitting a segfault when malloc() is called. I'm still debugging this, but thought I'd share it in the meantime in case it's a useful test case, at least for testing a large grammar file.

FYI, the attached .gbnf file is not production-ready in terms of representing valid SQL grammar, as many of the less-standard tokens are simply marked "TODO" for now. (The plan is to verify its size is compatible with this branch first before cleaning it up.)

sql-2003-2.gbnf.gz

mattpulver, Jun 25 '23 17:06

Thanks @mattpulver !

ejones, Jun 25 '23 19:06

I've been playing with this for the past week and it works splendidly. Truly a game-changer for building robust LLM applications. Thanks for the great work @ejones!

zakkor, Jun 27 '23 23:06

@ejones What Unicode library do you have in mind? I'm not exactly sure what the token space looks like for non-English text, but I would expect that some kind of normalisation (in the Unicode sense) or grapheme clusters would have to be properly supported, at least for some grammars. I'm currently trying to work out a GPT-4-based solution that will hopefully be able to generate highly specific (overfitted, if you will) grammars to guide generation on the fly. Funnily enough, a great deal of it is in Ukrainian, and sometimes much of the actual grammar would represent specific combinations locally and more generic ones globally. Ukrainian (Cyrillic) is fairly straightforward in the extended grapheme cluster (EGC) sense, but other scripts, like Arabic, are less so, and Arabic is apparently getting more representation due to https://noor.tii.ae/ and similar projects.

tucnak, Jun 28 '23 10:06

@SlyEcho updated to add a bit more structure to the in-memory grammar. Kept it somewhat flat still (each rule is an array of atomic elements) for memory locality and simplicity in the C API. Let me know what you think. Also added UTF-8 decoding such that the "characters" of the grammar are now code points. It should now support Unicode in the grammar files and in generations.

ejones, Jun 29 '23 05:06
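Decoding token bytes into code points, as the update above describes, can be sketched in C++ as follows. This is a minimal, hypothetical decoder, not the PR's code: it skips most validation, and it simply drops an incomplete sequence at the end of the input, whereas a sampler would need to carry it over to the next token, since a tokenizer with byte fallback can split one character across several tokens.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Minimal UTF-8 decoder sketch: turns the bytes of a token string into code points.
// No validation of continuation bytes, overlong forms, or surrogates. A stray byte
// becomes U+FFFD; an incomplete sequence at the end is dropped.
std::vector<uint32_t> decode_utf8(const std::string& s) {
    std::vector<uint32_t> cps;
    std::size_t i = 0;
    while (i < s.size()) {
        unsigned char b = static_cast<unsigned char>(s[i]);
        uint32_t cp;
        std::size_t len;
        if (b < 0x80)                { cp = b;        len = 1; }  // 0xxxxxxx
        else if ((b & 0xE0) == 0xC0) { cp = b & 0x1F; len = 2; }  // 110xxxxx
        else if ((b & 0xF0) == 0xE0) { cp = b & 0x0F; len = 3; }  // 1110xxxx
        else if ((b & 0xF8) == 0xF0) { cp = b & 0x07; len = 4; }  // 11110xxx
        else { cps.push_back(0xFFFD); ++i; continue; }            // stray or invalid lead byte
        if (i + len > s.size()) break;                            // truncated sequence at the end
        for (std::size_t k = 1; k < len; ++k)                     // append 6 bits per continuation byte
            cp = (cp << 6) | (static_cast<unsigned char>(s[i + k]) & 0x3F);
        cps.push_back(cp);
        i += len;
    }
    return cps;
}
```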

Re: SQL grammar support: your recent commit f8baad2 added useful diagnostic info to identify malformed/undefined rules, which existed in the attachment above, sql-2003-2.gbnf.gz. This helped me create an improved sql-92.gbnf file that is both smaller and passes the improved grammar checking.

After making the fixes I'm still getting a similar SEGFAULT as before:

$ gdb --args bin/main --grammar-file sql-92.gbnf -n32 -p "How many states are in the United States?\n"
...
high ::= [2] | [H] [i] [g] [h] left-paren [2] right-paren 
root ::= direct-SQL-statement 

Program received signal SIGSEGV, Segmentation fault.
0x00007ffff7aab70d in ?? () from /usr/lib/libc.so.6
(gdb) bt
#0  0x00007ffff7aab70d in ?? () from /usr/lib/libc.so.6
#1  0x00007ffff7aace93 in ?? () from /usr/lib/libc.so.6
#2  0x00007ffff7aad82a in malloc () from /usr/lib/libc.so.6
#3  0x00007ffff7cb089d in operator new (sz=310240) at /usr/src/debug/gcc/gcc/libstdc++-v3/libsupc++/new_op.cc:50
#4  0x000055555557b4c4 in std::__new_allocator<llama_grammar_element const*>::allocate (this=0x7fffff7ff300, __n=38780)
    at /usr/include/c++/13.1.1/bits/new_allocator.h:147
#5  0x000055555557a3f3 in std::allocator_traits<std::allocator<llama_grammar_element const*> >::allocate (__n=38780, 
    __a=...) at /usr/include/c++/13.1.1/bits/alloc_traits.h:482
#6  std::_Vector_base<llama_grammar_element const*, std::allocator<llama_grammar_element const*> >::_M_allocate (
    this=0x7fffff7ff300, __n=38780) at /usr/include/c++/13.1.1/bits/stl_vector.h:378
#7  0x000055555559ee32 in std::vector<llama_grammar_element const*, std::allocator<llama_grammar_element const*> >::_M_realloc_insert<llama_grammar_element const* const&> (this=0x7fffff7ff300, __position=0x7ffff7bf23c0)
    at /usr/include/c++/13.1.1/bits/vector.tcc:459
#8  0x0000555555596f70 in std::vector<llama_grammar_element const*, std::allocator<llama_grammar_element const*> >::push_back (this=0x7fffff7ff300, __x=@0x7fffff7ff2d8: 0x5555559829e0) at /usr/include/c++/13.1.1/bits/stl_vector.h:1289
#9  0x0000555555580df8 in llama_grammar_advance_stack (rules=std::vector of length 1351, capacity 1351 = {...}, 
    stack=std::vector of length 19391, capacity 38780 = {...}, 
    new_stacks=std::vector of length 77557, capacity 131072 = {...})
    at ejones/llama.cpp/llama.cpp:1900
#10 0x0000555555580e0f in llama_grammar_advance_stack (rules=std::vector of length 1351, capacity 1351 = {...}, 
    stack=std::vector of length 19391, capacity 38780 = {...}, 
    new_stacks=std::vector of length 77557, capacity 131072 = {...})
    at ejones/llama.cpp/llama.cpp:1902
#11 0x0000555555580e0f in llama_grammar_advance_stack (rules=std::vector of length 1351, capacity 1351 = {...}, 
    stack=std::vector of length 19391, capacity 38780 = {...}, 
    new_stacks=std::vector of length 77557, capacity 131072 = {...})
    at ejones/llama.cpp/llama.cpp:1902
...
#58189 0x0000555555580e0f in llama_grammar_advance_stack (rules=std::vector of length 1351, capacity 1351 = {...}, 
    stack=std::vector of length 1, capacity 1 = {...}, new_stacks=std::vector of length 77557, capacity 131072 = {...})
    at ejones/llama.cpp/llama.cpp:1902
#58190 0x0000555555580e0f in llama_grammar_advance_stack (rules=std::vector of length 1351, capacity 1351 = {...}, 
    stack=std::vector of length 1, capacity 1 = {...}, new_stacks=std::vector of length 77557, capacity 131072 = {...})
    at ejones/llama.cpp/llama.cpp:1902
#58191 0x000055555558150b in llama_grammar_init (rules=0x55555598f560, n_rules=1351, start_rule_index=1350)
    at ejones/llama.cpp/llama.cpp:2043
#58192 0x000055555555c188 in main (argc=9, argv=0x7fffffffe578)
    at ejones/llama.cpp/examples/main/main.cpp:311

It appears that new_stacks in llama_grammar_advance_stack() grows as O(n^2) with each recursive call, which may be a problem for large grammar files like this one. Put roughly: is it possible to navigate through the stack tree without having to materialize each step?

mattpulver, Jun 29 '23

Gonna look into this soon.

SlyEcho, Jun 29 '23

@mattpulver I reproduced the segfault, and it appears the problem in this case is left-recursive rules such as query-expression -> non-join-query-expression -> query-expression or query-term -> non-join-query-term -> query-term. Since the grammar is processed top-down (with no special handling of left recursion), it recurses infinitely when expanding these initial references.

A workaround would be to adjust the rules to eliminate left recursion. E.g., maybe something along these lines?

query-expression ::= query-term (("UNION" | "EXCEPT") "ALL"? corresponding-spec? query-term)* |
                   ... (other cases if needed)
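The rewrite above is an instance of the textbook transformation for direct left recursion: a rule A ::= A a1 | ... | A am | b1 | ... | bn matches the same strings as A ::= (b1 | ... | bn) (a1 | ... | am)*, and the second form never re-enters A before consuming input. The C++ sketch below mechanizes that step for a single rule. The function name eliminate_direct_left_recursion is hypothetical (not part of llama.cpp); it assumes alternatives are passed as GBNF fragment strings, that a self-reference is followed by a single space, and that at least one non-recursive alternative exists. Indirect cycles such as query-expression -> non-join-query-expression -> query-expression would first need the intermediate rule inlined to become direct.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper (not part of llama.cpp). Rewrites a directly
// left-recursive rule
//   A ::= A a1 | ... | A am | b1 | ... | bn
// into the equivalent
//   A ::= (b1 | ... | bn) (a1 | ... | am)*
// Alternatives are GBNF fragments; a left-recursive one is assumed to start
// with the rule name followed by a single space. At least one non-recursive
// alternative (some b_j) is assumed to exist.
std::string eliminate_direct_left_recursion(const std::string& name,
                                            const std::vector<std::string>& alts) {
    std::vector<std::string> tails;  // a_i: what follows the leading self-reference
    std::vector<std::string> bases;  // b_j: alternatives that do not start with A
    for (const std::string& alt : alts) {
        bool self = alt.compare(0, name.size(), name) == 0 &&
                    (alt.size() == name.size() || alt[name.size()] == ' ');
        if (!self) {
            bases.push_back(alt);
        } else if (alt.size() > name.size()) {
            tails.push_back(alt.substr(name.size() + 1));
        }  // a bare "A ::= A" alternative adds nothing and is dropped
    }
    auto join = [](const std::vector<std::string>& parts) {
        std::string out;
        for (size_t i = 0; i < parts.size(); ++i) {
            if (i > 0) out += " | ";
            out += parts[i];
        }
        return out;
    };
    std::string rhs = "(" + join(bases) + ")";
    if (!tails.empty()) rhs += " (" + join(tails) + ")*";
    return name + " ::= " + rhs;
}
```

For example, passing "query-expression" with the alternatives `query-expression "UNION" query-term` and `query-term` yields `query-expression ::= (query-term) ("UNION" query-term)*`, which is the same shape as the hand-written suggestion.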

We could probably at least detect left recursion and error out early. Longer term, we could look into an implementation that would support such grammars.
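Detecting left recursion up front amounts to finding a cycle in the graph of rules that can be referenced before any input is consumed. The sketch below uses a hypothetical, simplified grammar model (rule references plus terminals that always consume at least one character; this is not llama.cpp's llama_grammar_element layout). It first computes which rules can match the empty string, then adds an edge r -> q whenever q can appear at the start of r, and finally runs a DFS looking for a cycle. The nullable step matters because A ::= B A is still left-recursive when B can match nothing.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical simplified grammar model: a symbol is a reference to rule
// `ref` (ref >= 0) or a terminal that always consumes input (ref == -1).
struct Sym { int ref; };
using Alt = std::vector<Sym>;
using Rule = std::vector<Alt>;

// Returns true if some rule can reach itself without consuming input.
bool has_left_recursion(const std::vector<Rule>& rules) {
    const size_t n = rules.size();

    // 1. Fixpoint: which rules can match the empty string?
    std::vector<bool> nullable(n, false);
    for (bool changed = true; changed;) {
        changed = false;
        for (size_t r = 0; r < n; ++r) {
            if (nullable[r]) continue;
            for (const Alt& alt : rules[r]) {
                bool all_nullable = true;
                for (const Sym& s : alt) {
                    if (s.ref < 0 || !nullable[s.ref]) { all_nullable = false; break; }
                }
                if (all_nullable) { nullable[r] = true; changed = true; break; }
            }
        }
    }

    // 2. Edge r -> q if q can be expanded before any input is consumed.
    std::vector<std::vector<int>> leftmost(n);
    for (size_t r = 0; r < n; ++r) {
        for (const Alt& alt : rules[r]) {
            for (const Sym& s : alt) {
                if (s.ref < 0) break;             // terminal consumes input
                leftmost[r].push_back(s.ref);
                if (!nullable[s.ref]) break;      // continue only past empty-matching rules
            }
        }
    }

    // 3. DFS cycle detection (0 = unvisited, 1 = on current path, 2 = done).
    std::vector<int> color(n, 0);
    std::function<bool(int)> dfs = [&](int r) {
        color[r] = 1;
        for (int q : leftmost[r]) {
            if (color[q] == 1) return true;
            if (color[q] == 0 && dfs(q)) return true;
        }
        color[r] = 2;
        return false;
    };
    for (size_t r = 0; r < n; ++r) {
        if (color[r] == 0 && dfs(static_cast<int>(r))) return true;
    }
    return false;
}
```

A check like this could run once at grammar load time and report the offending rules instead of letting the top-down expansion recurse until the stack overflows.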

ejones, Jul 01 '23

@tucnak at this point the grammars are defined over code points, with no specific handling or recognition of grapheme clusters. A particular grammar could recognize grapheme clusters (and in effect normalize them), and maybe that's what you're working towards with your solution?
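To make the code-point distinction concrete: the same visible character can be one code point or several. The sketch below is a minimal UTF-8 decoder (it assumes well-formed input and does no validation). It shows that precomposed é decodes to the single code point U+00E9, while e followed by a combining acute accent decodes to U+0065 U+0301. Since grammars match code points, a rule written with one form would not, by itself, match the other.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Minimal UTF-8 decoder: assumes well-formed input, performs no validation.
std::vector<uint32_t> decode_utf8(const std::string& s) {
    std::vector<uint32_t> out;
    size_t i = 0;
    while (i < s.size()) {
        unsigned char c = static_cast<unsigned char>(s[i]);
        // Sequence length from the lead byte: 0xxxxxxx, 110xxxxx, 1110xxxx, 11110xxx.
        int len = c < 0x80 ? 1 : (c >> 5) == 0x6 ? 2 : (c >> 4) == 0xE ? 3 : 4;
        uint32_t cp = len == 1 ? c : len == 2 ? (c & 0x1F) : len == 3 ? (c & 0x0F) : (c & 0x07);
        // Each continuation byte contributes its low 6 bits.
        for (int k = 1; k < len && i + k < s.size(); ++k) {
            cp = (cp << 6) | (static_cast<unsigned char>(s[i + k]) & 0x3F);
        }
        out.push_back(cp);
        i += len;
    }
    return out;
}
```

Both strings render as "é", but a code-point grammar sees one symbol in the first case and two in the second; normalizing or explicitly recognizing grapheme clusters would have to be done by the grammar itself.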

ejones, Jul 01 '23

Great contribution, @ejones, kudos! Can't wait to see this merged!

Just for reference, I've been trying this locally and it works like a charm! I went a bit further and tried it with the Go bindings (branch at: https://github.com/go-skynet/go-llama.cpp/tree/grammar) and LocalAI (https://github.com/go-skynet/LocalAI). The result is that it's now possible to emulate OpenAI functions and run their examples directly:

[Screenshot: localai-functions-1]

I wanted to give more data points and highlight that it also correctly chooses not to use any of the functions. From a first hands-on with it (and a personal, empirical set of tests), it looks like the 1-char lookahead sampling is good enough if the model is "good" enough (I tested with WizardLM 7B), but it's very sensitive to the prompt:

[Screenshot: functions-2]
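For readers less familiar with constrained sampling, the general idea can be sketched as logit masking: before sampling, any candidate token whose text cannot extend the output along some valid path gets its logit set to negative infinity, so it can never be chosen. The sketch below is a generic illustration, not this PR's implementation. It replaces the grammar with a hypothetical list of allowed complete outputs (for example a function name, or "none" to decline) and checks each candidate's whole text against it; the Candidate type and function names are illustrative.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <string>
#include <vector>

// Hypothetical candidate type: a token's decoded text plus its logit.
struct Candidate {
    std::string text;
    float logit;
};

// True if `s` is a prefix of at least one allowed complete output.
bool is_viable_prefix(const std::string& s, const std::vector<std::string>& allowed) {
    for (const std::string& a : allowed) {
        if (a.size() >= s.size() && a.compare(0, s.size(), s) == 0) return true;
    }
    return false;
}

// Logit masking: a candidate that would take `generated` off every allowed
// path gets -inf, so a subsequent softmax assigns it zero probability.
void mask_candidates(std::vector<Candidate>& cands, const std::string& generated,
                     const std::vector<std::string>& allowed) {
    for (Candidate& c : cands) {
        if (!is_viable_prefix(generated + c.text, allowed)) {
            c.logit = -std::numeric_limits<float>::infinity();
        }
    }
}
```

Masking only removes invalid choices; among the valid ones (here "get" versus "no"), the model's own logits still decide, which is consistent with output quality depending on the model and the prompt.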

mudler, Jul 02 '23

JFYI: after playing with it a bit more, I bumped into this while trying it on ARM64+CUDA:

Jul 04 20:37:13 localhost local-ai[34380]: LLAMA_ASSERT: /usr/local/LocalAI/go-llama/llama.cpp/llama.cpp:2479: !new_stacks.empty()                                                                                   
Jul 04 20:37:13 localhost local-ai[34380]: SIGABRT: abort     

This happened while using the bindings. Update: I can't reproduce it with llama.cpp directly. Ignore me; it must be something incorrect in the bindings.

mudler, Jul 04 '23

@mudler thanks for the feedback! Re: OpenAI functions, I also have a draft up at #1887 with a script to convert JSON schemas to grammars. Re: that assertion, it is triggered when the sampled token doesn't match the grammar at all. If you do run into it on llama.cpp, let me know the inputs and I'm happy to look into it :).
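The assertion discussed here guards the case where no grammar path survives a sampled token. The sketch below models that idea with a heavily simplified, hypothetical representation: each stack is the remaining sequence of character ranges for one alternative (rule references already expanded, stored top-last), and accepting a token feeds its characters one at a time, dropping every stack whose top does not match. If all stacks are dropped, the token is rejected, which is the situation a check like `!new_stacks.empty()` exists to catch. Names such as accept_token are illustrative, not llama.cpp's API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical simplified model: a stack holds the character ranges still
// to be matched for one alternative, with the next expected range at back().
struct Range { char lo, hi; };
using Stack = std::vector<Range>;

// Advance every live stack by one character; stacks whose top does not
// match are dropped. An empty result means the character was rejected.
std::vector<Stack> accept_char(const std::vector<Stack>& stacks, char c) {
    std::vector<Stack> next;
    for (const Stack& st : stacks) {
        if (st.empty()) continue;  // this alternative is already complete
        const Range& top = st.back();
        if (top.lo <= c && c <= top.hi) {
            next.emplace_back(st.begin(), st.end() - 1);  // pop the matched range
        }
    }
    return next;
}

// Feed a whole token's text. Returns false (leaving `stacks` unchanged) if
// at some point no stack survives, i.e. the token does not match the grammar.
bool accept_token(std::vector<Stack>& stacks, const std::string& text) {
    std::vector<Stack> cur = stacks;
    for (char c : text) {
        cur = accept_char(cur, c);
        if (cur.empty()) return false;
    }
    stacks = cur;
    return true;
}
```

With two alternatives "ab" and "ac", the token "x" is rejected without touching the stacks, while "ab" is accepted and leaves a single completed stack. In a correct pipeline, sampling should never pick a token that fails this check, so hitting the assertion points at a mismatch between the candidates that were filtered and the token that was actually accepted.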

ejones, Jul 06 '23

Yes, awesome job! I did look at that PR, and slightly adapted it to Go to generate grammars directly from the requests. My first attempts were simpler, though: a chain that first lets the model choose an action and then fill in the params (to force it into some kind of 'reasoning' step).
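The first step of that chain (forcing the model to pick one action before filling in parameters) can be expressed as a small generated grammar. The C++ sketch below is hypothetical (it is neither the #1887 script nor the Go port): action_grammar builds a GBNF grammar whose root only admits a JSON object of the form {"function": "<name>"} for one of the given names, assuming the names contain no characters that need escaping. Including a name like "none" gives the model an explicit way to decline.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: builds a GBNF grammar that forces output of the form
//   {"function": "<name>"}
// where <name> is one of `names`. Names are assumed to need no escaping.
std::string action_grammar(const std::vector<std::string>& names) {
    // root ::= "{" ws "\"function\":" ws action ws "}"
    std::string g = "root ::= \"{\" ws \"\\\"function\\\":\" ws action ws \"}\"\n";
    // action ::= "\"name1\"" | "\"name2\"" | ...
    g += "action ::= ";
    for (size_t i = 0; i < names.size(); ++i) {
        if (i > 0) g += " | ";
        g += "\"\\\"" + names[i] + "\\\"\"";
    }
    // ws ::= [ \t\n]*
    g += "\nws ::= [ \\t\\n]*\n";
    return g;
}
```

A second grammar, built from the chosen function's parameter schema, would then constrain the parameter-filling step; that part is not shown here.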

However, what I'm seeing is quite weird: it doesn't happen when using llama.cpp directly, only when using it through the Go bindings (https://github.com/go-skynet/go-llama.cpp) on a particular setup I have. It happens only on ARM; on my x86_64 machine it runs fine. The same grammar (basically equivalent to the output of your sample in #1887) works fine on x86_64 but crashes on ARM+CUDA with the error above.

I suspect something weird is going on with the toolchain combination (gcc/nvcc). I've tried to trace it back with gdb, with no luck so far; it does seem that there is no match with the grammar rules (even though it matches on x86_64!).

I really appreciate your help, thank you so much, but I don't want to bother you. Running llama.cpp directly doesn't seem to cause any issues. I'll collect more data and see whether it's something that can be replicated.

mudler, Jul 06 '23