
llama : speed-up grammar sampling

Open ggerganov opened this issue 7 months ago • 30 comments

There have been a few reports that grammar sampling can significantly degrade performance. It would be nice to profile and optimize the implementation; there should be room for improvement.

Already on-going efforts:

  • #4210
  • #4213

It's probably worth looking into multi-threading the implementation as well.

ggerganov, Nov 25 '23

#3980 and this suggestion might also help a bit https://github.com/ggerganov/llama.cpp/issues/3980#issuecomment-1826269575

I would have expected the compiler to optimize it straight away 🤷🏻

ExtReMLapin, Nov 27 '23

Would an integration of Outlines help? Like they are doing with vLLM: https://github.com/outlines-dev/outlines/issues/163

gottlike, Nov 28 '23

@ExtReMLapin This copy is used only in the speculative example. Even if it helps there, it won't have any effect on the general use case. Still, a PR is welcome.

@gottlike An efficient low-level solution as the one we currently have seems like a better approach to me.

ggerganov, Nov 28 '23

I noticed that at some point inference gets exponentially slower when there are a lot of deeply nested but open grammar rules, where "open" means rules with many different possibilities. For example, I am working on PydanticModel -> JsonSchema -> Grammar, and when the model outputs a long list of nested sub-objects, this slowdown kicks in until at some point generation gets stuck.

shroominic, Dec 01 '23

@shroominic On my end it just gets slower the longer it spends printing the JSON array, with no nested objects.

ExtReMLapin, Dec 01 '23

I found a similar exponential slowdown to the one @shroominic mentioned in my use case, which is generating code in a language similar to OCaml. Generation was very fast for the first 200 tokens, but the time per token grew to more than 400 seconds as I approached 300 tokens.

I plotted the grammar stack size and the duration per token over time and found stack size to be the main factor in the slowdown. The number of grammar stacks can reach 800K for my grammar. I'm not very familiar with the grammar sampling algorithm used in llama.cpp, but I suspect it's exponential in the length of the parsed string. Polynomially bounded parsing algorithms like the Earley parser might help avoid the exponential blowup.

Grammar
root ::= [ \t\n]* exp

ws ::= [ \t\n]+
w ::= [ \t]*

comment ::= "#" [^#]* "#" [ \t]+ [\n]? [ \t]*

### Expressions

exp ::= comment* sequence-exp

sequence-exp ::= tuple-exp (w ";" ws tuple-exp)*

tuple-exp ::= cons-exp (w "," ws cons-exp)*

cons-exp ::= binary-exp (w "::" w binary-exp)*

binary-exp ::= unary-exp (ws binary-op ws unary-exp)*

unary-exp ::= unary-op* function-app-exp

function-app-exp ::= primary-exp (w "(" w exp w ")" w)*

primary-exp ::= bool |
    integer |
    float |
    string |
    variable |
    "()" |
    "[]" |
    constructor |
    constructor-app |
    parenthesized-exp |
    list-exp |
    let-exp |
    if-exp |
    case-exp |
    test-exp |
    type-alias |
    fun

constructor-app ::= constructor "(" w exp w ")"
parenthesized-exp ::= "(" w exp w ")"
list-exp ::= "[" exp ("," ws exp)* "]"
let-exp ::= "let" ws pat ws "=" ws exp ws "in" ws exp
if-exp ::= "if" ws exp ws "then" ws exp ws "else" ws exp
case-exp ::= "case" ws exp (ws "|" ws pat ws "=>" ws exp)+ ws "end"
test-exp ::= "test" ws exp ws "end"
type-alias ::= "type" ws constructor ws "=" ws typ ws "in" ws exp
fun ::= "fun" ws pat ws "->" ws exp

type-variable ::= [a-z][A-Za-z0-9_]*
constructor ::= [A-Z][A-Za-z0-9_]*
variable ::= ([_a-bdg-hj-kn-qu-z][A-Za-z0-9_.]*)|(("s" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("st" ([.0-9A-Z_a-qs-z][A-Za-z0-9_.]*)?)|("str" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("stru" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("struc" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("struct" [A-Za-z0-9_.]+)|("c" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ca" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("cas" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("case" [A-Za-z0-9_.]+)|("i" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("in" [A-Za-z0-9_.]+)|("r" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("re" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("rec" [A-Za-z0-9_.]+)|("t" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("te" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("tes" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("test" [A-Za-z0-9_.]+)|("l" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("le" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("let" [A-Za-z0-9_.]+)|("m" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ma" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("mat" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("matc" ([.0-9A-Z_a-gi-z][A-Za-z0-9_.]*)?)|("match" [A-Za-z0-9_.]+)|("f" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("fu" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("fun" [A-Za-z0-9_.]+)|("e" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("en" ([.0-9A-Z_a-ce-z][A-Za-z0-9_.]*)?)|("end" [A-Za-z0-9_.]+))
bool ::= "true" | "false"
integer ::= [0-9]+
float ::= [0-9]* "." [0-9]+
string ::= "\"" [^"]* "\""

unary-op ::= "-" | "!"
binary-op-int ::= "+" | "-" | "*" | "/" | "<" | ">" | "<=" | ">=" | "==" | "!="
binary-op-float ::= "+." | "-." | "*." | "/." | "<." | ">." | "<=." | ">=." | "==." | "!=."
binary-op-string ::= "$==" | "@"
binary-op-logic ::= "&&"
binary-op ::= binary-op-int | binary-op-float | binary-op-string | binary-op-logic

### Patterns

pat ::= type-ascription-pat

type-ascription-pat ::= tuple-pat (w ":" ws typ)*

tuple-pat ::= cons-pat (w "," ws cons-pat)*

cons-pat ::= primary-pat (w "::" w primary-pat)*

primary-pat ::=
    bool |
    integer |
    float |
    string |
    variable |
    "()" |
    "[]" |
    "_" |
    constructor |
    constructor-app-pat |
    parenthesized-pat |
    list-pat

constructor-app-pat ::= constructor "(" w pat w ")"
parenthesized-pat ::= "(" w pat w ")"
list-pat ::= "[" pat (w "," ws pat)* "]"

### Types

typ ::= arrow-typ

arrow-typ ::= tuple-typ (ws "->" ws tuple-typ)*

tuple-typ ::= primary-typ (w "," ws primary-typ)*

primary-typ ::=
    "Unit" |
    "Int" |
    "Float" |
    "Bool" |
    "String" |
    type-variable |
    constructor |
    constructor-def (ws "+" ws constructor-def)+ |
    parenthesized-typ |
    list-typ

parenthesized-typ ::= "(" w typ w ")"
list-typ ::= "[" w typ w "]"
constructor-def ::= constructor | constructor "(" w typ w ")"
Prompt
### Option ###

# Represent values that may or may not exist. #
type Option =
  + Some(?)
  + None
in

# Compare if two Options are equal #
# equal: ((?, ?) -> Bool) -> (Option, Option) -> Bool #
let equal: ((?, ?) -> Bool) -> (Option, Option) -> Bool =
    fun eq, os ->
        case os
        | Some(x), Some(y) => eq(x, y)
        | None, None => True
        | _, _ => False
        end
in

### Result ###

# A Result is either Ok meaning the computation succeeded, #
# or it is an Err meaning that there was some failure. #
type Result =
  + Ok(a)
  + Err(b)
in

# Compare if two Results are equal #
# equal: ((a, a) -> Bool) -> (Result, Result) -> Bool #
let equal: ((a, a) -> Bool) -> (Result, Result) -> Bool =
    fun eq, rs ->
        case rs
        | Ok(e1), Ok(e2) => eq(e1, e2)
        | Error(e1), Error(e2) => e1 $== e2
        | _ => false
        end
in


### JSON ###
# This module helps you convert between Hazel values and JSON values. #

# A JSON value type #
type Value =
  + Object([(String, Value)])
  + Array([Value])
  + Str(String)
  + Number(Float)
  + Boolean(Bool)
  + Null 
in

# Check if two JSON values are equal #
# equal : (Value,Value) -> Bool #
let equal : (Value,Value) -> Bool =
fun a, b ->
    case (a, b)
    | Object(o1), Object(o2) => List.equal(
        fun (s1, v1), (s2, v2) ->
            s1 $== s2 && equal(v1, v2), o1, o2)
    | Array(a1), Array(a2) => List.equal(equal, a1, a2)
    | Str(s1), Str(s2) => s1 $== s2
    | Number(n1), Number(n2) => n1 ==. n2
    | Boolean(b1), Boolean(b2) => if b1 then b2 else !b2
    | Null, Null => true
    | _ => false
  end 
in

# JSON Encoder #

# Convert a string to a JSON string #
# value_of_string : String -> Value #
let value_of_string : String -> Value =
    fun s -> Str(s) 
in

# Convert an integer to a JSON integer #
# value_of_int : Int -> Value #
let value_of_int : Int -> Value =
    fun i -> Number(float_of_int(i)) 
in

# Convert a float to a JSON float #
# value_of_float : Float -> Value #
let value_of_float : Float -> Value =
    fun f -> Number(f) 
in

# Convert a boolean to a JSON boolean #
# value_of_bool : Bool -> Value #
let value_of_bool : Bool -> Value =
    fun b -> if b then Boolean(true) else Boolean(false)
in

# Convert a null to a JSON null #
# value_of_null : Value #
let value_of_null : Value = Null in

# Convert a list of JSON values to a JSON array #
# value_of_list : (a -> Value, [a]) -> Value #
let value_of_list : (a -> Value, [a]) -> Value =
  fun (func, entries) ->
    Array(
      List.rev(List.fold_left(
        fun l, e-> func(e)::l, [], entries)))
in

# Convert a dictionary of JSON values to a JSON object #
# value_of_object : [(String, Value)] -> Value #
let value_of_object : [(String, Value)] -> Value =
    fun entries -> Object(entries)
in

# JSON Decoder #
# A Decoder decodes a JSON value into a Hazel value, or return an Err on failure. #
type Decoder = Value -> Result in

# Decodes a JSON string into a string #
# string_of_value : Decoder #
let string_of_value : Decoder =
    fun v ->
        case v
        | Str(s) => Ok(s)
        | _ => Err("Cannot unpack value as a String")
        end
in

# Decodes a JSON boolean into a boolean #
# bool_of_value : Decoder #
let bool_of_value : Decoder =
  fun v ->
    case v
    | Boolean(b) => Ok(b)
    | _ => Err("Cannot unpack value as a Bool")
    end
in

# Decodes a JSON integer into an integer #
# int_of_value : Decoder #
let int_of_value : Decoder =
    fun v ->
        case v
        | Number(n) =>
            if floor(n) ==. n then
                # n is a whole number #
                Ok(floor(n))
            else
                # n is a floating point #
                Err("Cannot unpack a float value as an Int")
        | _ => Err("Cannot unpack value as an Int") 
        end
in

# Decodes a JSON float into a float #
# float_of_value : Decoder #
let float_of_value : Decoder =
fun v ->
    case v
    | Number(n) => Ok(floor(n))
    | _ => Err("Cannot unpack value as a Float")
    end
in

# Decodes a JSON null into a null #
# null_of_value : Decoder #
let null_of_value : Decoder =
  fun v ->
    case v
    | Null => Ok(None)
    | _ => Err("Cannot unpack value as a None")
    end
in

# Parsers #
# Try a bunch of different decoders. #
# This can be useful if the JSON may come in a couple different formats. #
# one_of : [Decoder] -> Decoder #
let one_of : [Decoder] -> Decoder =
    fun decoders -> fun v ->
        case decoders
        | decoder::decoders =>
            result_map_err(fun _ -> one_of(decoders)(v), decoder(v))
        | [] => Err("one_of failed to decode value")
        end
    in

# Transform a decoder. #
# map : ((a -> b), Decoder) -> Decoder #
let map : ((a -> b), Decoder) -> Decoder =
    fun (func, decoder) -> fun v ->
        case decoder(v)
        | Err(e) => Err(e)
        | Ok(o) => func(o)
in

# Create decoders that depend on previous results. #
# and_then: ((a -> Decoder), Decoder) -> Decoder #
let and_then: ((a -> Decoder), Decoder) -> Decoder =
    fun (func, decoder) ->
        fun v ->
            case decoder(v) 
            | Err(e) => Err(e)
            | Ok(o)=> func(o)(v)
            end
in

# Decode a nullable JSON value into a Hazel value. #
# nullable : Decoder -> Decoder #
let nullable : Decoder -> Decoder =
    fun decoder ->
        one_of([
            map(fun s -> Some(s), decoder),
            null_of_value
        ])
in

# Decode a JSON array into a Hazel List. #
# list : Decoder -> Decoder #
let list : Decoder -> Decoder =
    fun elem_decoder ->
    fun v ->
        case v 
        | Array(arr) => 
    case arr 
    | head::tail =>
    case elem_decoder(head) 
    | Ok(hd) => map(fun tl -> hd::tl, list(elem_decoder))(Array(tail))
    | Err(e) => Err(e)
        end
        | [] => Ok([])
        end
    | _ => Err("Cannot unpack value as a List")
    end
in

# Decode a JSON object into a Hazel dictionary. #
# For now, a dictionary is just a list of key-value pairs #
# dict : Decoder -> Decoder #
let dict : Decoder -> Decoder =
  fun value_decoder ->
    fun v ->
        case v 
        | Object(pairs) =>
            case pairs
            | (key, value)::tail =>
                case value_decoder(value) 
                | Ok(hd)=> map(fun tl -> (key, hd)::tl, dict(value_decoder))(Object(tail))
                | Err(e) => Err(e)
                end
            | [] => Ok([])
            end
        | _ => Err("Cannot unpack value as a dict")
        end
in


### List ###

# Add an element to the front of a list. #
# cons: (a, [a]) -> [a] #
let cons: (a, [a]) -> [a] = fun x, xs -> x::xs in

# Determine the length of a list. #
# length: [a] -> Int #
let length: [a] -> Int =
  fun xs ->
    case xs
    | [] => 0
    | _::xs => 1 + length(xs) end in

# Extract the first element of a list. #
# hd: [a] -> Option #
let hd: [a] -> Option =
  fun l ->
    case l
    | [] => None
    | x::xs => Some(x) end in

# Extract the rest of the list. #
# tl: [a] -> [a] #
let tl: [a] -> [a] =
  fun l ->
    case l
    | [] => []
    | x::xs => xs end in

# Determine if a list is empty. #
# is_empty: [a] -> Bool #
let is_empty: [a] -> Bool =
  fun xs ->
    case xs
    | [] => true
    | _::_ => false end in

# Return the element at the index. #
# nth: ([a], Int) -> Option #
let nth: ([a], Int) -> Option =
  fun xs, n ->
    case xs, n
    | x::_, 0 => Some(x)
    | _::xs, n => nth(xs, n - 1)
    | [], _ => None end in

# Reverse a List. #
# rev: [a] -> [a] #
let rev: [a] -> [a] =
fun l -> 
let go: ([a], [a]) -> [a] =
  fun xs, acc -> 
    case xs 
    | [] => acc 
    | x::xs => go(xs, x::acc) end in
go(l, []) in

# Check if two lists are equal #
# equal: ((a, a) -> Bool, [a], [a]) -> Bool #
let equal: ((a, a) -> Bool, [a], [a]) -> Bool =
    fun p, xs, ys ->
    case xs, ys
    | [], [] => true
    | x::xs, y::ys => p(x, y) && equal(p, xs, ys)
    | _ => false end
in

# Initialize a list with a given length using an initializer function #
# init: (Int, Int -> a) -> [a] #
let init: (Int, Int -> a) -> [a] =
    fun len, f ->
        let go: (Int, [a]) -> [a] =
        fun idx, xs ->
            if idx < len 
            then go(idx + 1, xs @ [f(idx)])   
            else xs
        in
        go(0, [])
in

# Reduce a list from the left. #
# fold_left: ((b, a) -> b, b, [a]) -> b #
let fold_left: ((b, a) -> b, b, [a]) -> b =
  fun f, acc, xs ->
    case xs 
    | [] => acc
    | hd::tl => fold_left(f, f(acc, hd), tl) end in

# Reduce a list from the right. #
# fold_right: ((a, b) -> b, [a], b) -> b #
let fold_right: ((a, b) -> b, [a], b) -> b =
  fun f, xs, acc ->
    case xs
    | [] => acc
    | hd::tl => f(hd, fold_right(f, tl, acc)) end in

# A simplified lambda calculus expression containing variables, lambdas, and applications #
type Exp =
  + Var(String)
  + Lam(String, Exp)
  + Ap(Exp, Exp)
in
# Evaluation can result in either an Exp or an Error #
# Evaluation by substitution #
# eval: Exp -> Result #
Command
./main \
    --grammar-file grammar.gbnf \
    -t 10 \
    -ngl 64 \
    -b 512 \
    -m ../models/codellama-34b.Q5_K_M.gguf \
    --color -c 3400 \
    --temp 0.7 \
    --repeat_penalty 1.1 \
    -n -1 \
    -f prompt.txt

[Plot: grammar stack size and duration per token over time (exp_eval_grammar_stack_size)]

AlienKevin, Dec 01 '23


After looking into the code, I think there's a seemingly obvious and much simpler way to optimize grammar sampling, even without threading.

Right now, it manually checks all token candidates and removes any candidates that would violate the grammar.

It would be much more efficient, and simpler, to sample the normal way, check whether the chosen token violates the grammar before proceeding with it, and only if it does, fall back to the current behavior that 'forces' the grammar. Right now, the sample set is 'forced' to match the grammar before picking. This is resource-intensive for large-vocabulary models and largely unnecessary, since the model will naturally follow the grammar most of the time with typical sampler settings (especially with Min P / low temperature), so the full grammar computation would only need to run some of the time.
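The proposal above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not llama.cpp code: `sample` and `grammar_allows` are hypothetical callbacks standing in for the real sampler chain and the per-token grammar check, and `sample_with_lazy_grammar` / `LazyGrammarResult` are illustrative names. The result also reports how many grammar checks were needed, which is the quantity the proposal aims to cut.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

// Hypothetical callbacks (not llama.cpp APIs): a sampler that picks a token id
// from a logits vector, and a predicate saying whether the grammar accepts a token.
using Sampler = std::function<int(const std::vector<float>&)>;
using GrammarCheck = std::function<bool(int)>;

struct LazyGrammarResult {
    int token;                   // sampled token that satisfies the grammar
    std::size_t grammar_checks;  // number of grammar checks performed
};

// Sample normally first and verify only the chosen token; fall back to masking
// the whole vocabulary only when that token violates the grammar.
LazyGrammarResult sample_with_lazy_grammar(const std::vector<float>& logits,
                                           const Sampler& sample,
                                           const GrammarCheck& grammar_allows) {
    assert(!logits.empty());
    const int first = sample(logits);
    if (grammar_allows(first)) {
        return {first, 1};  // fast path: a single grammar check
    }
    // Slow path: work on a copy of the original logits, reject every token the
    // grammar disallows, and sample again.
    std::vector<float> masked = logits;
    std::size_t checks = 1;
    for (std::size_t id = 0; id < masked.size(); ++id) {
        ++checks;
        if (!grammar_allows(static_cast<int>(id))) {
            masked[id] = -std::numeric_limits<float>::infinity();
        }
    }
    return {sample(masked), checks};
}
```

With a greedy sampler (`std::max_element` over the logits), logits `{0.1, 2.0, 0.5, 1.0}` and a grammar that only allows even token ids, the first pick (token 1) is rejected, every token is then checked, and token 2 is returned. For plain sampling from a fixed softmax distribution, accepting an allowed first draw and otherwise resampling from the masked logits gives the same distribution as masking first, because a rejected draw is replaced by a draw from the renormalized allowed set. When truncating samplers such as top-k or min-p run before the check, the two orders can behave differently.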

@ejones Any suggestions for how I would go about implementing a solution?

kalomaze, Dec 03 '23

I'm doing some investigation. I think the easiest way to do this without refactoring the grammar code itself is to run the existing grammar check on only the single sampled candidate in sampling.cpp. If it passes, we proceed. If it fails, we restart sampling, this time running:

    if (ctx_sampling->grammar != NULL) {
        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
    }

before the repetition penalty or any other modifications are applied to the logits. I plan to achieve this by making a copy of the initial logits for the "2nd pass".

kalomaze, Dec 03 '23


https://github.com/ggerganov/llama.cpp/pull/4306

I have made a pull request that should reduce the number of grammar checks to 1 for most tokens, instead of checking all 32,000 tokens in the vocabulary. I haven't evaluated whether it is actually faster yet, but I'm guessing that skipping thousands of UTF-8 decoding steps for most tokens would improve performance.

kalomaze, Dec 03 '23

@AlienKevin thanks for investigating! Yeah, it's a simple top-down backtracking parser, so it can be exponential in the worst case. It works best for grammars with little or no ambiguity or nondeterminism; a deterministic grammar should maintain a constant stack count. This isn't obvious, though, and we could probably do a better job of signaling it to the user.

@kalomaze looks great, commented on your PR as well.

ejones, Dec 06 '23

Grammar processing appears to be quite slow (again?): https://github.com/ggerganov/llama.cpp/pull/4306#issuecomment-1947021051

txbm, Feb 15 '24

No issue on my end

ExtReMLapin, Feb 15 '24

I've noticed it varies widely with prompt complexity. The grammar generated from my JSON schema contains three levels of nested object arrays. If I ask for a shorter output, it completes reasonably quickly with output conforming to the schema, and runs at consistently high CPU utilization.

But if I ask for an output about ten times longer with the exact same schema, resource utilization (mainly CPU) becomes highly variable and rarely sustains the maximum. Overall inference time gets long enough (30+ minutes) that it's not worth waiting for the task to complete, whereas the exact same prompt consistently finishes in about 7 minutes with the grammar/schema removed. If a full repro would help, I'm happy to link a Gist.

txbm, Feb 15 '24

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions[bot], Apr 03 '24

Just to close the loop on my previous comment: I continued experimenting with this feature on a wide variety of cases and ultimately concluded that the performance variance is too large for production use, even on fast GPUs such as the A100.

I would, however, very much like to see this feature perform consistently enough for production as it is otherwise very useful. I am happy to help with testing or reproduction of test cases if anyone decides to work on this.

txbm, Apr 04 '24

I've been digging into this lately, and I've been using the integration tests in #6472 to do some crude performance profiling.

I've definitely seen the sort of dramatic stack expansion that @AlienKevin is talking about. I think there are many causes, but one I've been digging into is how easily ambiguous alternates in a grammar can dramatically inflate the number of stacks. For instance, imagine a grammar that says:

root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? [0-9])*

This is a rather extreme example, but hopefully it illustrates the point.

If you have an input string that looks like "1a2a3a4a5", then at the first character there is only 1 stack. But the parser doesn't know whether the first 'a' matches the first a? in the grammar or one of the later ones, so it needs to track all 5 possibilities (plus one stack expecting the next digit and one for the end of input). We now have 7 stacks, which then grows to 15, 35, 75, and so on.

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 7
Parsing character 2 ('2'), stack size 15
Parsing character 3 ('a'), stack size 35
Parsing character 4 ('3'), stack size 75
Parsing character 5 ('a'), stack size 175
Parsing character 6 ('4'), stack size 375
Parsing character 7 ('a'), stack size 875
Parsing character 8 ('5'), stack size 1875
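The counts above can be reproduced with a small standalone model of this stack-based matching. It is a sketch, not llama.cpp's implementation: `Elem`, `Pos`, `Grammar`, `expand_stack`, `stack_counts` and `flat_optional_grammar` are hypothetical names, and the grammar is hand-encoded the way `X?` and `X*` are rewritten into auxiliary rules (`X' ::= X X' | empty`). Each live stack holds the grammar positions still to be matched, with a character range on top; the optional `dedup` flag keeps only one copy of identical stacks after each step.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// One element of a rule alternative: a reference to another rule, or an
// inclusive character range [lo, hi].
struct Elem {
    bool is_ref;
    int rule;
    char lo, hi;
};
using Alt = std::vector<Elem>;
using Grammar = std::vector<std::vector<Alt>>;  // rule id -> alternatives; rule 0 is root

// Element `idx` of alternative `alt` of rule `rule`.
struct Pos {
    int rule, alt, idx;
    bool operator<(const Pos& o) const {
        return std::tie(rule, alt, idx) < std::tie(o.rule, o.alt, o.idx);
    }
};
using Stack = std::vector<Pos>;  // back() is the next element to match

static const Elem& elem_at(const Grammar& g, const Pos& p) { return g[p.rule][p.alt][p.idx]; }
static bool has_next(const Grammar& g, const Pos& p) {
    return p.idx + 1 < static_cast<int>(g[p.rule][p.alt].size());
}

// Expand rule references on top of `stack` until every resulting stack is empty
// (the input may end here) or has a character range on top.
static void expand_stack(const Grammar& g, Stack stack, std::vector<Stack>& out) {
    if (stack.empty() || !elem_at(g, stack.back()).is_ref) {
        out.push_back(std::move(stack));
        return;
    }
    const Pos top = stack.back();
    stack.pop_back();
    if (has_next(g, top)) stack.push_back({top.rule, top.alt, top.idx + 1});
    const int ref = elem_at(g, top).rule;
    for (int a = 0; a < static_cast<int>(g[ref].size()); ++a) {
        Stack s = stack;  // one new stack per alternative of the referenced rule
        if (!g[ref][a].empty()) s.push_back({ref, a, 0});
        expand_stack(g, std::move(s), out);
    }
}

// Number of live stacks before each input character is consumed.
std::vector<std::size_t> stack_counts(const Grammar& g, const std::string& input, bool dedup) {
    std::vector<Stack> stacks;
    for (int a = 0; a < static_cast<int>(g[0].size()); ++a) expand_stack(g, Stack{Pos{0, a, 0}}, stacks);
    std::vector<std::size_t> counts;
    for (char c : input) {
        counts.push_back(stacks.size());
        std::vector<Stack> next;
        for (const Stack& st : stacks) {
            if (st.empty()) continue;  // this stack has already matched everything
            const Pos top = st.back();
            const Elem& e = elem_at(g, top);
            if (c < e.lo || c > e.hi) continue;  // character rejected by this stack
            Stack s(st.begin(), st.end() - 1);
            if (has_next(g, top)) s.push_back({top.rule, top.alt, top.idx + 1});
            expand_stack(g, std::move(s), next);
        }
        if (dedup) {  // keep one copy of each identical stack
            std::set<Stack> uniq(next.begin(), next.end());
            next.assign(uniq.begin(), uniq.end());
        }
        stacks = std::move(next);
    }
    return counts;
}

// root ::= [0-9] ("a"? repeated `slots` times, then [0-9])*, encoded the way the
// rewrite rules expand it: each "a"? is a rule ("a" | empty), the group is a rule,
// and X* becomes X' ::= X X' | empty.
Grammar flat_optional_grammar(int slots) {
    const Elem digit{false, 0, '0', '9'};
    const Elem a{false, 0, 'a', 'a'};
    Grammar g(3 + slots);  // 0: root, 1: X', 2: the group, 3..: the "a"? rules
    g[0].push_back(Alt{digit, Elem{true, 1, 0, 0}});
    Alt group;
    for (int i = 0; i < slots; ++i) {
        g[3 + i].push_back(Alt{a});
        g[3 + i].push_back(Alt{});
        group.push_back(Elem{true, 3 + i, 0, 0});
    }
    group.push_back(digit);
    g[2].push_back(group);
    g[1].push_back(Alt{Elem{true, 2, 0, 0}, Elem{true, 1, 0, 0}});
    g[1].push_back(Alt{});
    return g;
}
```

`stack_counts(flat_optional_grammar(5), "1a2a3a4a5", false)` returns 1, 7, 15, 35, 75, 175, 375, 875, 1875, the same progression as above, and with eight slots the input `1aa2` gives 1, 10, 36, 84. With `dedup` set to true, the counts alternate between 7 and 5 after the first character, because stacks that matched an `a` in different slots end up at identical grammar positions.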

If we change our input string to "1aa2aa3aa4aa5", it gets even worse, because there are even more ways to distribute the a's across the optional slots:

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 7
Parsing character 2 ('a'), stack size 15
Parsing character 3 ('2'), stack size 20
Parsing character 4 ('a'), stack size 70
Parsing character 5 ('a'), stack size 150
Parsing character 6 ('3'), stack size 200
Parsing character 7 ('a'), stack size 700
Parsing character 8 ('a'), stack size 1500
Parsing character 9 ('4'), stack size 2000
Parsing character 10 ('a'), stack size 7000
Parsing character 11 ('a'), stack size 15000
Parsing character 12 ('5'), stack size 20000

On my laptop (MacBook M1, 32 GB RAM), this noticeably lags the machine, and it stalls even on this (relatively) short grammar and validation string. I've done tests with much larger grammars that push MUCH larger stack sizes, and the ambiguity can really drag things down, even to the point of memory exhaustion.

Currently I'm tackling some left-recursion issues that can cause segfaults (#6492), but after I'm done with that (or give up), I'm going to tackle these ambiguity-related performance issues. I'm not entirely sure, but I think there should be a viable algorithm to prune near-redundant stacks whose outcomes would be equivalent. For example, if we've got four potential "a"s to match and only one "a", it doesn't matter which one we choose; the others can be discarded once we get to the next token, as long as they all converge onto the same spot in the same rule.

This isn't the only performance-related issue that grammars are seeing, but I believe this massive exponential growth in stack size is one of the biggest opportunities for optimization.

Anyways, that's where I'm at on things -- will keep y'all posted.

HanClinto, Apr 04 '24

I'm not very familiar with the current setup of our CI performance profilers -- if I were to make improvements to the grammar engine, would those speed improvements show up in our current bank of benchmarks?

HanClinto, Apr 04 '24

if I were to make improvements to the grammar engine, would those speed improvements show up in our current bank of benchmarks?

We don't have benchmarks for this yet. You will have to do some custom profiling to determine how the performance changes. With time, we can attempt to add some sort of speed information about the grammar to the CI or at least some local tools.

ggerganov, Apr 05 '24

how easy it is for alternate grammars to dramatically inflate the stack. For instance, imagine a grammar that says: root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? [0-9])*

@HanClinto One potentially big source of such explosive repetition is JSON grammars with minItems/maxItems (or with JSON string regexp patterns such as {"type": "string", "pattern": "a{3,10}"}). The easy workaround right now is to rewrite the rule as:

root ::= [0-9] (("a" ("a" ("a" ("a" ("a")?)?)?)?)? [0-9])*

I'm working on fixing the JSON grammar conversion to do this (e.g. https://github.com/ggerganov/llama.cpp/commit/375f85dd573915eb758e4607b156a59e8cd0dbf6) and hope to send a PR soon. I'll probably update the GBNF doc with performance caveats in the same PR.
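A minimal sketch of that rewrite as a string-level desugaring, assuming a repetition `item{min,max}` with a finite `max`; `nested_optional` and `bounded_repetition` are hypothetical names, not the converter's actual code. The key property is that the optional tail is nested, so any given number of copies can be matched in only one way, instead of one way per choice of slots.

```cpp
#include <cassert>
#include <string>

// Nested form for "between 0 and n copies of item": ( item ( item (...)? )? )?
// Each length of the optional tail can be matched in exactly one way.
std::string nested_optional(const std::string& item, int n) {
    assert(n >= 0);
    if (n == 0) return "";
    const std::string inner = nested_optional(item, n - 1);
    return "(" + item + (inner.empty() ? "" : " " + inner) + ")?";
}

// Desugar item{min,max}: `min` mandatory copies followed by the nested optional tail.
std::string bounded_repetition(const std::string& item, int min, int max) {
    assert(0 <= min && min <= max);
    std::string out;
    for (int i = 0; i < min; ++i) out += (out.empty() ? "" : " ") + item;
    const std::string tail = nested_optional(item, max - min);
    if (!tail.empty()) out += (out.empty() ? "" : " ") + tail;
    return out;
}
```

`nested_optional("\"a\"", 5)` produces exactly the `("a" ("a" ("a" ("a" ("a")?)?)?)?)?` group used above, and for the `a{3,10}` pattern, `bounded_repetition("\"a\"", 3, 10)` yields three mandatory `"a"` followed by a nested tail of seven optional ones.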

ochafik, Apr 08 '24

@ochafik That indeed is a massive improvement! Testing your grammar against 1a2a3a4a5 gives:

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 3
Parsing character 2 ('2'), stack size 2
Parsing character 3 ('a'), stack size 3
Parsing character 4 ('3'), stack size 2
Parsing character 5 ('a'), stack size 3
Parsing character 6 ('4'), stack size 2
Parsing character 7 ('a'), stack size 3
Parsing character 8 ('5'), stack size 2

And testing against the worse case of 1aa2aa3aa4aa5 gives:

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 3
Parsing character 2 ('a'), stack size 2
Parsing character 3 ('2'), stack size 2
Parsing character 4 ('a'), stack size 3
Parsing character 5 ('a'), stack size 2
Parsing character 6 ('3'), stack size 2
Parsing character 7 ('a'), stack size 3
Parsing character 8 ('a'), stack size 2
Parsing character 9 ('4'), stack size 2
Parsing character 10 ('a'), stack size 3
Parsing character 11 ('a'), stack size 2
Parsing character 12 ('5'), stack size 2

Indeed, that's a massive savings. Nice speedup!

It would still be nice to find ways to speed up the grammar tree navigation even with more poorly written grammars, but improving the quality of the grammars in this way is a huge help.

HanClinto, Apr 08 '24

ways to speed up the grammar tree navigation even with more poorly written grammars

@HanClinto I'd be inclined to detect some easily rewritable grammar cases on the fly, and to fail loudly when the grammar becomes too combinatorial (with a link to a "performance tips" section of the GBNF wiki page), either via a cap on stack size or maybe a built-in timeout.

Maybe some new features, such as numbered repetition operators like "a"{,5} (desugared as in the grammar above) and other regexp-derived syntax (reluctant/eager modifiers?), could also make it easier to write efficient grammars.

ochafik, Apr 08 '24

FWIW, I've wanted bounded repetitions a few times with this grammar; recursive whitespace sometimes lets the model spin forever.

It would also be great to see some stats on the grammar as well, either after running or after desugaring so that we can optimize grammars.

o1lo01ol1o, Apr 08 '24

@HanClinto I'd be inclined to detect some easily rewritable grammar cases on the fly and explode when the grammar becomes too combinatorial (w/ a link to a "performance tips" section of the GBNF wiki page), either with a cap on stack size or some builtin timeout maybe?

This sounds really good. Detecting and gracefully exiting is the first step. Worst case is that things explode in an infinite loop (as is the case with left-recursion, as in #6492), and implementing a max stack size of 1024 or something should at least give a reasonable approach.
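A minimal sketch of such a guard, assuming the matcher can report how many stacks are live after each character. `check_stack_budget` and `kMaxGrammarStacks` are hypothetical names, and 1024 is simply the figure floated above. This bounds the number of stacks only; the unbounded recursion that left-recursive rules trigger would need a separate depth limit.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical limit on the number of live grammar stacks.
constexpr std::size_t kMaxGrammarStacks = 1024;

// Throw a descriptive error instead of letting an ambiguous grammar keep
// multiplying stacks until the machine runs out of memory.
void check_stack_budget(std::size_t n_stacks, std::size_t limit = kMaxGrammarStacks) {
    assert(limit > 0);
    if (n_stacks > limit) {
        throw std::runtime_error("grammar produced " + std::to_string(n_stacks) +
                                 " parse stacks (limit " + std::to_string(limit) +
                                 "); consider rewriting ambiguous optional/repeated rules");
    }
}
```

Called after each accepted character, the 875 stacks from the "1a2a3a4a5" example would pass, while the 1875 at the next step would raise an error naming the count.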

Maybe also some new features like numbered repetition operators "a"{,5} (desugared as the grammar above) and maybe other regexp derived syntax features (reluctant/eager modifiers?) could make it easier to write efficient grammars.

Yeah, I think that could be pretty reasonable. I'm also wondering about parse_sequence() in grammar-parser.cpp, and this section of rewriting +*? operators:

// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |

Makes me wonder if there is a better way to do this (more akin to what you wrote above) but I'm still relatively new to grammar engines.

I'm learning a lot about grammar parsing as part of this exercise -- I never read the dragon book, but I'm wondering if I should order myself a copy as part of this work. :)

Fwiw, I've wanted bounded repetitions a few times with this grammar; recursive white space sometimes lets the model spin forever.

@o1lo01ol1o Yeah, poor handling of whitespace seems to be an issue.

Similar to what @ochafik is saying, I wonder if a good first step would be an optimization pass that, after expanding grammars, always collapses adjacent optional tokens into something more efficient, e.g. condensing ws? ws? into something like ws{,2}.
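A sketch of that pass over a flattened rule body, assuming the body is a list of symbol strings in which a trailing `?` marks an optional symbol. The `{0,N}` in the output stands for the proposed bounded-repetition syntax, which would then be desugared into a nested form rather than N independent optionals; `collapse_optional_runs` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Replace each run of N > 1 identical optional symbols ("ws?" "ws?") with a
// single bounded repetition ("ws{0,2}"); everything else is copied unchanged.
std::vector<std::string> collapse_optional_runs(const std::vector<std::string>& seq) {
    std::vector<std::string> out;
    std::size_t i = 0;
    while (i < seq.size()) {
        std::size_t j = i;
        while (j < seq.size() && seq[j] == seq[i]) ++j;  // [i, j) is a run of equal symbols
        const std::string& sym = seq[i];
        const bool optional = sym.size() > 1 && sym.back() == '?';
        if (optional && j - i > 1) {
            out.push_back(sym.substr(0, sym.size() - 1) + "{0," + std::to_string(j - i) + "}");
        } else {
            out.insert(out.end(), seq.begin() + static_cast<std::ptrdiff_t>(i),
                       seq.begin() + static_cast<std::ptrdiff_t>(j));
        }
        i = j;
    }
    return out;
}
```

For example, the body `ws? ws? "x" ws?` becomes `ws{0,2} "x" ws?`: only runs of identical optional symbols are merged, and non-optional repeats are left alone.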

HanClinto, Apr 08 '24

I'm also wondering about parse_sequence() in grammar-parser.cpp, and this section of rewriting +*? operators: // S* --> S' ::= S S' | ... Makes me wonder if there is a better way to do this (more akin to what you wrote above) but I'm still relatively new to grammar engines.

@HanClinto I don't think these rewrites are problematic, as they don't change the number of ways things can be parsed. I'm not sure I understand the grammar decoding logic but it seems it may be keeping the stacks for all the possible ways to decode the current sequence (~~I'd speculate it's done to avoid backtracking as we don't want to ungenerate tokens~~ Edit: that's so we get the union of all next tokens that may match any of the possible ways of parsing; prime target for acceleration by a quantum computer? 🤪). The a? a? a? a? a? rule has many possible ways to parse / generate the same thing (e.g. 'aaa' might require 10 stacks, cf. combinations).
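To put numbers on the combinations mentioned above: with n independent optional slots, k copies of "a" can be matched in C(n, k) ways (one per choice of which slots are used), so 'aaa' against five `"a"?` slots has C(5, 3) = 10 parses, while the nested rewrite has exactly one. A small sketch; `ways_flat` and `ways_nested` are illustrative names.

```cpp
#include <cassert>
#include <cstdint>

// Ways to match k copies of "a" against n independent optional slots
// ("a"? "a"? ... "a"?): choose which k of the n slots are used, i.e. C(n, k).
std::uint64_t ways_flat(unsigned n, unsigned k) {
    if (k > n) return 0;
    std::uint64_t c = 1;
    for (unsigned i = 1; i <= k; ++i) {
        c = c * (n - k + i) / i;  // exact: c equals C(n - k + i, i) after each step
    }
    return c;
}

// In the nested form ("a" ("a" (...)?)?)?, the copies must fill the slots
// front to back, so any k <= n matches in exactly one way.
std::uint64_t ways_nested(unsigned n, unsigned k) { return k <= n ? 1 : 0; }
```

For eight slots and two copies, `ways_flat(8, 2)` is 28, while `ways_nested(8, 2)` stays 1.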

That parse_sequence seems like a good place to implement the bounded repetition operator, though: you could desugar the x{min,max} syntax with code similar to this build_repetitions, maybe.
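As a sketch of the parsing side, here is one way to read a `{min,max}` suffix in a hand-written recursive-descent style. The function name, its signature, and the accepted forms (`{n}`, `{m,}`, `{,n}`, `{m,n}`, with -1 meaning unbounded) are assumptions for illustration, not the grammar parser's actual interface.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <stdexcept>
#include <string>

struct RepetitionBounds {
    int min;
    int max;  // -1 means unbounded, as in {m,}
};

// Read a decimal number at src[pos], advancing pos; returns -1 if there is none.
static int read_int(const std::string& src, std::size_t& pos) {
    int value = -1;
    while (pos < src.size() && std::isdigit(static_cast<unsigned char>(src[pos]))) {
        value = (value < 0 ? 0 : value * 10) + (src[pos] - '0');
        ++pos;
    }
    return value;
}

// Parse a repetition suffix starting at src[pos] == '{': accepts {n}, {m,}, {,n}
// and {m,n}; on success, pos is left just past the closing '}'.
RepetitionBounds parse_repetition_bounds(const std::string& src, std::size_t& pos) {
    assert(pos < src.size() && src[pos] == '{');
    ++pos;
    const int first = read_int(src, pos);
    RepetitionBounds b{first < 0 ? 0 : first, first};
    if (pos < src.size() && src[pos] == ',') {
        ++pos;
        b.max = read_int(src, pos);  // stays -1 for {m,}
    } else if (first < 0) {
        throw std::runtime_error("expected a number in repetition bounds");
    }
    if (pos >= src.size() || src[pos] != '}') throw std::runtime_error("expected '}'");
    ++pos;
    if (b.max >= 0 && b.max < b.min) throw std::runtime_error("max is smaller than min");
    return b;
}
```

The parsed bounds could then feed a desugaring step such as the nested-optional rewrite discussed above, with `{m,}` expanding to m copies followed by `item*`.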

I'm learning a lot about grammar parsing as part of this exercise -- I never read the dragon book, but I'm wondering if I should order myself a copy as part of this work. :)

Haven't read this one but +1000 to learning about parsers and compilers before LLMs completely take that away from us 😅

ochafik, Apr 08 '24

I've taken a deeper look at how grammar sampling works. Besides some possible tactical improvements (e.g. extra reserve calls and reusing stack vectors together give over 20% speedup on my test grammars), I'm trying to implement some classical parser optimizations. One is precomputing the « head set » of character ranges allowed recursively at the start of each rule alternative, so as to avoid instantiating most of the alternatives: currently we "preheat" stacks for all possible upcoming alternatives before consuming chars, whereas we could add only the stacks for alternatives whose head sets accept each char. I hope to send PRs in the next couple of weeks if the results are good enough 🤞.
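A simplified sketch of the head-set filter, assuming each alternative's head set has already been computed as a 256-bit byte set. Computing those sets properly requires following rule references recursively and handling alternatives that can match the empty string, which this sketch skips. `Alternative`, `byte_range` and `viable_alternatives` are hypothetical names.

```cpp
#include <bitset>
#include <cassert>
#include <string>
#include <vector>

struct Alternative {
    std::string name;
    std::bitset<256> head;  // bytes that can start a match of this alternative
};

// Set of all bytes in the inclusive range [lo, hi].
std::bitset<256> byte_range(unsigned char lo, unsigned char hi) {
    std::bitset<256> s;
    for (unsigned c = lo; c <= hi; ++c) s.set(c);
    return s;
}

// Only alternatives whose head set contains `c` need a stack at all; the rest
// are skipped instead of being instantiated and then discarded.
std::vector<std::string> viable_alternatives(const std::vector<Alternative>& alts, unsigned char c) {
    std::vector<std::string> out;
    for (const Alternative& alt : alts) {
        if (alt.head.test(c)) out.push_back(alt.name);
    }
    return out;
}
```

For a JSON-like value rule with `number` (head `[0-9-]`), `string` (head `"`) and `true` (head `t`), the next character `t` only requires instantiating the `true` alternative.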

ochafik, Apr 11 '24

I was able to make a pretty big speed improvement last night for ambiguous grammars. With ambiguous grammars, many stacks end up as exact duplicates, and pruning the duplicates trims the stack count.
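One simple way to do that pruning is to sort the stacks and drop adjacent duplicates after each character. This is a sketch that assumes a stack can be represented as a vector of grammar-position ids; `GrammarStack` and `prune_duplicate_stacks` are hypothetical names. Only exact duplicates are removed, so the set of continuations the grammar accepts is unchanged.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using GrammarStack = std::vector<int>;  // hypothetical: one id per grammar position

// Remove stacks that are identical to another stack. Identical stacks accept
// exactly the same continuations, so keeping one copy loses nothing.
void prune_duplicate_stacks(std::vector<GrammarStack>& stacks) {
    std::sort(stacks.begin(), stacks.end());
    stacks.erase(std::unique(stacks.begin(), stacks.end()), stacks.end());
}
```

Sorting costs O(n log n) comparisons per character, which is small next to the exponential growth it prevents on grammars like the one below.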

The grammar that I'm using for all of these tests is `root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? [0-9])*`

This is what the stacks look like in a pretty simple case:

Grammar stack progression for `1aa2` without culling:
Parsing character 0 ('1'), stack size 1
  Stack 0: 3 48,
Parsing character 1 ('a'), stack size 10
  Stack 0: 2 10, 2 3, 3 97,
  Stack 1: 2 10, 2 4, 3 97,
  Stack 2: 2 10, 2 5, 3 97,
  Stack 3: 2 10, 2 6, 3 97,
  Stack 4: 2 10, 2 7, 3 97,
  Stack 5: 2 10, 2 8, 3 97,
  Stack 6: 2 10, 2 9, 3 97,
  Stack 7: 2 10, 3 48, 3 97,
  Stack 8: 2 10, 3 48,
  Stack 9:
Parsing character 2 ('a'), stack size 36
  Stack 0: 2 10, 2 4, 3 97,
  Stack 1: 2 10, 2 5, 3 97,
  Stack 2: 2 10, 2 6, 3 97,
  Stack 3: 2 10, 2 7, 3 97,
  Stack 4: 2 10, 2 8, 3 97,
  Stack 5: 2 10, 2 9, 3 97,
  Stack 6: 2 10, 3 48, 3 97,
  Stack 7: 2 10, 3 48,
  Stack 8: 2 10, 2 5, 3 97,
  Stack 9: 2 10, 2 6, 3 97,
  Stack 10: 2 10, 2 7, 3 97,
  Stack 11: 2 10, 2 8, 3 97,
  Stack 12: 2 10, 2 9, 3 97,
  Stack 13: 2 10, 3 48, 3 97,
  Stack 14: 2 10, 3 48,
  Stack 15: 2 10, 2 6, 3 97,
  Stack 16: 2 10, 2 7, 3 97,
  Stack 17: 2 10, 2 8, 3 97,
  Stack 18: 2 10, 2 9, 3 97,
  Stack 19: 2 10, 3 48, 3 97,
  Stack 20: 2 10, 3 48,
  Stack 21: 2 10, 2 7, 3 97,
  Stack 22: 2 10, 2 8, 3 97,
  Stack 23: 2 10, 2 9, 3 97,
  Stack 24: 2 10, 3 48, 3 97,
  Stack 25: 2 10, 3 48,
  Stack 26: 2 10, 2 8, 3 97,
  Stack 27: 2 10, 2 9, 3 97,
  Stack 28: 2 10, 3 48, 3 97,
  Stack 29: 2 10, 3 48,
  Stack 30: 2 10, 2 9, 3 97,
  Stack 31: 2 10, 3 48, 3 97,
  Stack 32: 2 10, 3 48,
  Stack 33: 2 10, 3 48, 3 97,
  Stack 34: 2 10, 3 48,
  Stack 35: 2 10, 3 48,
Parsing character 3 ('2'), stack size 84
  Stack 0: 2 10, 2 5, 3 97,
  Stack 1: 2 10, 2 6, 3 97,
  Stack 2: 2 10, 2 7, 3 97,
  Stack 3: 2 10, 2 8, 3 97,
  Stack 4: 2 10, 2 9, 3 97,
  Stack 5: 2 10, 3 48, 3 97,
  Stack 6: 2 10, 3 48,
  Stack 7: 2 10, 2 6, 3 97,
  Stack 8: 2 10, 2 7, 3 97,
  Stack 9: 2 10, 2 8, 3 97,
  Stack 10: 2 10, 2 9, 3 97,
  Stack 11: 2 10, 3 48, 3 97,
  Stack 12: 2 10, 3 48,
  Stack 13: 2 10, 2 7, 3 97,
  Stack 14: 2 10, 2 8, 3 97,
  Stack 15: 2 10, 2 9, 3 97,
  Stack 16: 2 10, 3 48, 3 97,
  Stack 17: 2 10, 3 48,
  Stack 18: 2 10, 2 8, 3 97,
  Stack 19: 2 10, 2 9, 3 97,
  Stack 20: 2 10, 3 48, 3 97,
  Stack 21: 2 10, 3 48,
  Stack 22: 2 10, 2 9, 3 97,
  Stack 23: 2 10, 3 48, 3 97,
  Stack 24: 2 10, 3 48,
  Stack 25: 2 10, 3 48, 3 97,
  Stack 26: 2 10, 3 48,
  Stack 27: 2 10, 3 48,
  Stack 28: 2 10, 2 6, 3 97,
  Stack 29: 2 10, 2 7, 3 97,
  Stack 30: 2 10, 2 8, 3 97,
  Stack 31: 2 10, 2 9, 3 97,
  Stack 32: 2 10, 3 48, 3 97,
  Stack 33: 2 10, 3 48,
  Stack 34: 2 10, 2 7, 3 97,
  Stack 35: 2 10, 2 8, 3 97,
  Stack 36: 2 10, 2 9, 3 97,
  Stack 37: 2 10, 3 48, 3 97,
  Stack 38: 2 10, 3 48,
  Stack 39: 2 10, 2 8, 3 97,
  Stack 40: 2 10, 2 9, 3 97,
  Stack 41: 2 10, 3 48, 3 97,
  Stack 42: 2 10, 3 48,
  Stack 43: 2 10, 2 9, 3 97,
  Stack 44: 2 10, 3 48, 3 97,
  Stack 45: 2 10, 3 48,
  Stack 46: 2 10, 3 48, 3 97,
  Stack 47: 2 10, 3 48,
  Stack 48: 2 10, 3 48,
  Stack 49: 2 10, 2 7, 3 97,
  Stack 50: 2 10, 2 8, 3 97,
  Stack 51: 2 10, 2 9, 3 97,
  Stack 52: 2 10, 3 48, 3 97,
  Stack 53: 2 10, 3 48,
  Stack 54: 2 10, 2 8, 3 97,
  Stack 55: 2 10, 2 9, 3 97,
  Stack 56: 2 10, 3 48, 3 97,
  Stack 57: 2 10, 3 48,
  Stack 58: 2 10, 2 9, 3 97,
  Stack 59: 2 10, 3 48, 3 97,
  Stack 60: 2 10, 3 48,
  Stack 61: 2 10, 3 48, 3 97,
  Stack 62: 2 10, 3 48,
  Stack 63: 2 10, 3 48,
  Stack 64: 2 10, 2 8, 3 97,
  Stack 65: 2 10, 2 9, 3 97,
  Stack 66: 2 10, 3 48, 3 97,
  Stack 67: 2 10, 3 48,
  Stack 68: 2 10, 2 9, 3 97,
  Stack 69: 2 10, 3 48, 3 97,
  Stack 70: 2 10, 3 48,
  Stack 71: 2 10, 3 48, 3 97,
  Stack 72: 2 10, 3 48,
  Stack 73: 2 10, 3 48,
  Stack 74: 2 10, 2 9, 3 97,
  Stack 75: 2 10, 3 48, 3 97,
  Stack 76: 2 10, 3 48,
  Stack 77: 2 10, 3 48, 3 97,
  Stack 78: 2 10, 3 48,
  Stack 79: 2 10, 3 48,
  Stack 80: 2 10, 3 48, 3 97,
  Stack 81: 2 10, 3 48,
  Stack 82: 2 10, 3 48,
  Stack 83: 2 10, 3 48,
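The dumps print each stack bottom-to-top as `type value` pairs. A small decoder, assuming the pairs use the numbering of llama.cpp's `llama_gretype` enum at the time (2 = `LLAMA_GRETYPE_RULE_REF`, 3 = `LLAMA_GRETYPE_CHAR`); the `decode_stack` helper is hypothetical. Under that reading, `2 10, 2 3, 3 97,` is "ref to rule 10, ref to rule 3, char 'a' on top", and `3 48` is the `'0'` lower bound of `[0-9]` (its upper bound lives in the next grammar element, which the dump doesn't print).

```python
# Assumed numbering of llama.cpp's llama_gretype enum (llama.h, early 2024).
GRETYPE = {
    0: "END", 1: "ALT", 2: "RULE_REF", 3: "CHAR",
    4: "CHAR_NOT", 5: "CHAR_RNG_UPPER", 6: "CHAR_ALT",
}
CHAR_TYPES = {"CHAR", "CHAR_NOT", "CHAR_RNG_UPPER", "CHAR_ALT"}


def decode_stack(line):
    """Decode one dumped stack such as "2 10, 2 3, 3 97," (bottom to top)."""
    decoded = []
    for entry in line.split(","):
        entry = entry.strip()
        if not entry:
            continue  # trailing comma, or an empty stack
        type_id, value = (int(field) for field in entry.split())
        kind = GRETYPE.get(type_id, f"UNKNOWN({type_id})")
        # Char-like elements store a code point; rule refs store a rule id.
        shown = repr(chr(value)) if kind in CHAR_TYPES else str(value)
        decoded.append(f"{kind} {shown}")
    return decoded
```

The last decoded entry is the top of the stack, i.e. the element the next character has to match.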

But by adding a culling step to remove duplicates after each time we advance the stacks, we get a much more reasonable stack progression:

Grammar stack progression for `1aa2` with culling:
Parsing character 0 ('1'), stack size 1
  Stack 0: 3 48,
Before culling, stack size 10
After culling, stack size 10
Parsing character 1 ('a'), stack size 10
  Stack 0: 2 10, 2 3, 3 97,
  Stack 1: 2 10, 2 4, 3 97,
  Stack 2: 2 10, 2 5, 3 97,
  Stack 3: 2 10, 2 6, 3 97,
  Stack 4: 2 10, 2 7, 3 97,
  Stack 5: 2 10, 2 8, 3 97,
  Stack 6: 2 10, 2 9, 3 97,
  Stack 7: 2 10, 3 48, 3 97,
  Stack 8: 2 10, 3 48,
  Stack 9:
Before culling, stack size 36
After culling, stack size 8
Parsing character 2 ('a'), stack size 8
  Stack 0: 2 10, 2 4, 3 97,
  Stack 1: 2 10, 2 5, 3 97,
  Stack 2: 2 10, 2 6, 3 97,
  Stack 3: 2 10, 2 7, 3 97,
  Stack 4: 2 10, 2 8, 3 97,
  Stack 5: 2 10, 2 9, 3 97,
  Stack 6: 2 10, 3 48, 3 97,
  Stack 7: 2 10, 3 48,
Before culling, stack size 28
After culling, stack size 7
Parsing character 3 ('2'), stack size 7
  Stack 0: 2 10, 2 5, 3 97,
  Stack 1: 2 10, 2 6, 3 97,
  Stack 2: 2 10, 2 7, 3 97,
  Stack 3: 2 10, 2 8, 3 97,
  Stack 4: 2 10, 2 9, 3 97,
  Stack 5: 2 10, 3 48, 3 97,
  Stack 6: 2 10, 3 48,
Before culling, stack size 10
After culling, stack size 10

And it runs MUCH faster.

When I tested against a much longer input, `1aa2aa3aa4aa5aa6`, the stack size grew to 51631104 by the end, and I was afraid that it wasn't going to finish.

Grammar stack progression for long string, unoptimized
Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 10
Parsing character 2 ('a'), stack size 36
Parsing character 3 ('2'), stack size 84
Parsing character 4 ('a'), stack size 280
Parsing character 5 ('a'), stack size 1008
Parsing character 6 ('3'), stack size 2352
Parsing character 7 ('a'), stack size 7840
Parsing character 8 ('a'), stack size 28224
Parsing character 9 ('4'), stack size 65856
Parsing character 10 ('a'), stack size 219520
Parsing character 11 ('a'), stack size 790272
Parsing character 12 ('5'), stack size 1843968
Parsing character 13 ('a'), stack size 6146560
Parsing character 14 ('a'), stack size 22127616
Parsing character 15 ('6'), stack size 51631104

After culling optimization:

Grammar stack progression for long string, optimized
Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 10
Parsing character 2 ('a'), stack size 8
Parsing character 3 ('2'), stack size 7
Parsing character 4 ('a'), stack size 10
Parsing character 5 ('a'), stack size 8
Parsing character 6 ('3'), stack size 7
Parsing character 7 ('a'), stack size 10
Parsing character 8 ('a'), stack size 8
Parsing character 9 ('4'), stack size 7
Parsing character 10 ('a'), stack size 10
Parsing character 11 ('a'), stack size 8
Parsing character 12 ('5'), stack size 7
Parsing character 13 ('a'), stack size 10
Parsing character 14 ('a'), stack size 8
Parsing character 15 ('6'), stack size 7

And it runs near-instantly, as opposed to over a minute for the unoptimized version.
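The stack sizes above can be reproduced with a toy model of this one grammar. It is a sketch, not llama.cpp's code: the stack symbols `"rep"`, the integer resume index, `"a"` and `"digit"` are hypothetical stand-ins for the dump's `2 10`, `2 k`, `3 97` and `3 48`. The model counts how many copies of each distinct stack exist, so the unculled total can be computed without materializing 51 million stacks, and culling just resets every count to 1. Culling is safe because two identical stacks accept exactly the same continuations.

```python
from collections import Counter

N_OPT = 8  # number of "a"? optionals inside the repeated group


def expand_from(k):
    """Stacks reachable when the next group element is optional k (1-based);
    k == N_OPT + 1 means the next element is the group's closing [0-9]."""
    # Optional j matches 'a' (optionals k..j-1 are skipped); the integer under
    # 'a' records where parsing resumes after that 'a' is consumed.
    stacks = [("rep", j + 1, "a") for j in range(k, N_OPT + 1)]
    stacks.append(("rep", "digit"))  # all remaining optionals skipped
    return stacks


def expand_rep():
    """rep ::= group rep | (empty): start another group, or finish."""
    return expand_from(1) + [()]


def successors(stack, ch):
    """Stacks produced by feeding `ch` to one stack ([] if it is rejected)."""
    if not stack:
        return []  # empty stack: parse complete, no more input accepted
    if stack[-1] == "a" and ch == "a":
        return expand_from(stack[-2])
    if stack[-1] == "digit" and ch.isdigit():
        return expand_rep()
    return []


def stack_sizes(text, cull):
    """Number of live stacks before each character of `text`."""
    stacks = Counter({("digit",): 1})  # root ::= [0-9] rep
    sizes = []
    for ch in text:
        sizes.append(sum(stacks.values()))
        nxt = Counter()
        for stack, copies in stacks.items():
            for succ in successors(stack, ch):
                nxt[succ] += copies
        if cull:
            nxt = Counter(dict.fromkeys(nxt, 1))  # drop duplicate stacks
        stacks = nxt
    return sizes
```

`stack_sizes("1aa2aa3aa4aa5aa6", cull=False)` returns the unoptimized progression above (1, 10, 36, 84, ..., 51631104), and `cull=True` returns the 1, 10, 8, 7, 10, 8, 7, ... progression.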

This won't help in other cases, but I'm pretty happy with this culling step, and I hope to submit a PR for this improvement soon.

HanClinto, Apr 11 '24 13:04

BTW, I spent this evening doing two different experiments for optimizations -- both turned up zero improvement.

The first was attempting to modify reject_candidates to modify the candidates array in-place to reduce allocation of vectors, but that still has a bug in it and I haven't fully tracked it down yet. I'm optimistic this might still bear some fruit at some point.
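The in-place idea is the classic compaction pattern (in C++, `std::remove_if` followed by `erase`): a write index trails the read index, survivors are moved forward, and the tail is truncated without allocating a second vector. A generic sketch with a hypothetical `filter_in_place` helper, not the actual `reject_candidates` code:

```python
def filter_in_place(candidates, accepts):
    """Keep only items for which accepts(item) is true, compacting the list
    in place instead of building a new one. Relative order is preserved."""
    write = 0
    for read in range(len(candidates)):
        if accepts(candidates[read]):
            candidates[write] = candidates[read]
            write += 1
    del candidates[write:]  # truncate the leftover tail
    return candidates
```

For example, filtering `[1, 2, 3, 4, 5, 6]` with an even-number predicate leaves the same list object holding `[2, 4, 6]`.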

My second attempt was to memoize llama_grammar_advance_stack, but I didn't see a speed improvement for it. It's possible that it will help more on larger sampling lengths, but I'm not hopeful.
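For reference, memoizing a stack-advancing function looks roughly like this in a toy setting (hypothetical grammar encoding and `advance` function, not llama.cpp's `llama_grammar_advance_stack`). Stacks must be immutable and hashable to serve as cache keys, and hashing or comparing a stack costs time proportional to its length, so the cache only pays off when the same stacks recur often.

```python
from functools import lru_cache

# Hypothetical toy grammar: rule -> alternatives (tuples of symbols).
# Symbols that are keys of GRAMMAR are rule references; anything else is a
# literal character.
#   REP ::= A A D REP | (empty)    A ::= "a" | (empty)    D ::= "0" | "1"
GRAMMAR = {
    "REP": (("A", "A", "D", "REP"), ()),
    "A": (("a",), ()),
    "D": (("0",), ("1",)),
}


@lru_cache(maxsize=None)
def advance(stack):
    """Expand rule references on top of `stack` (top = last element) until
    each resulting stack has a literal on top or is empty.
    Note: a left-recursive rule would make this recurse forever."""
    if not stack or stack[-1] not in GRAMMAR:
        return (stack,)
    rest = stack[:-1]
    out = []
    for alt in GRAMMAR[stack[-1]]:
        # Push the alternative reversed so its first symbol ends up on top.
        out.extend(advance(rest + tuple(reversed(alt))))
    return tuple(out)
```

A repeated call with an identical stack, e.g. `advance(("REP",))`, returns the cached tuple instead of re-expanding.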

Overall not much fruit for the evening -- we'll see what the rest of the weekend brings.

HanClinto, Apr 13 '24 04:04

> The first was attempting to modify reject_candidates to modify the candidates array in-place to reduce allocation of vectors, but that still has a bug in it and I haven't fully tracked it down yet. I'm optimistic this might still bear some fruit at some point.
>
> My second attempt was to memoize llama_grammar_advance_stack, but I didn't see a speed improvement for it. It's possible that it will help more on larger sampling lengths, but I'm not hopeful.

@HanClinto I've tried an embarrassingly large number of variants of these ideas too... In the end, good old profiling helped me identify what seem to be two big easy wins → https://github.com/ggerganov/llama.cpp/pull/6811 (makes grammar sampling itself up to 10x faster)

Basic commands to profile on Mac:
brew install gperftools
make clean && make -j LLAMA_GPROF=1 main
CPUPROFILE=llama.prof \
  DYLD_INSERT_LIBRARIES=/opt/homebrew/Cellar/gperftools/2.15/lib/libprofiler.dylib \
  ./main ...

pprof --pdf main llama.prof > main.pdf
open main.pdf

ochafik, Apr 21 '24 17:04

Why not use ANTLR? They have more grammars, and llama.cpp's C grammar is missing a lot.

ANTLR 4: https://github.com/antlr/grammars-v4/blob/master/c/C.g4
llama.cpp: https://github.com/ggerganov/llama.cpp/blob/master/grammars/c.gbnf

liljohnak, Apr 22 '24 23:04

@liljohnak ANTLR generates parsers that operate on a given string and backtrack when they pick the wrong branch if/when the grammar is ambiguous. In contrast, llama.cpp's grammar-constrained sampling (https://github.com/ggerganov/llama.cpp/pull/1773) lets the model pick whichever next token would result in a string partially parsable from the start. To that effect, it keeps all the possible parsing stacks around (they're continuation stacks: the back of each one points to the next char element that can be parsed).

As for their grammars, they can easily be converted to GBNF; for instance, I've let Opus have a go at the C one: c_antlr.gbnf... but it makes llama.cpp segfault for now (~~too many stacks?~~ Edit: probably just left recursion (https://github.com/ggerganov/llama.cpp/issues/6492); new frontier to conquer!).
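Left recursion is a plausible culprit because a top-down expander that meets `expr ::= expr "+" term` has to expand `expr` again before consuming any character, without end. The textbook fix rewrites immediate left recursion `A ::= A α | β` into `A ::= β A_tail` and `A_tail ::= α A_tail | ε` (in GBNF this can also be written as `A ::= β α*`). A minimal sketch; the `_tail` naming and list-of-symbols representation are hypothetical, and indirect left recursion through other rules is not handled:

```python
def remove_immediate_left_recursion(grammar):
    """Rewrite A ::= A a1 | ... | b1 | ... into
    A ::= b1 A_tail | ...   and   A_tail ::= a1 A_tail | ... | (empty).

    grammar: {rule: [alternative, ...]}, alternatives are lists of symbols.
    Only direct (immediate) left recursion is removed.
    """
    out = {}
    for rule, alts in grammar.items():
        recursive = [alt[1:] for alt in alts if alt and alt[0] == rule]
        base = [alt for alt in alts if not alt or alt[0] != rule]
        if not recursive:
            out[rule] = [list(alt) for alt in alts]
            continue
        tail = f"{rule}_tail"  # assumes this name is not already taken
        out[rule] = [alt + [tail] for alt in base]
        out[tail] = [alpha + [tail] for alpha in recursive] + [[]]
    return out
```

Applied to `expr ::= expr "+" term | term`, this produces `expr ::= term expr_tail` and `expr_tail ::= "+" term expr_tail | (empty)`, which an expander can handle because every `expr` alternative now starts with something other than `expr` itself.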

ochafik, Apr 23 '24 10:04