
BPE Tokenizer: Multiple newlines don't merge into a single token

Open · Lyrcaxis opened this issue 10 months ago • 2 comments

So, I found out that \n\n, when followed by another character, tokenizes as ['\n','\n'] ([198, 198]) instead of ['\n\n'] ([271]). (I'm using Llama 3 for this example, but this extends to other models as well.)

Here's an example prompt:

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You're Psy, user's assistant, and a master of concise replies.<|eot_id|><|start_header_id|>user<|end_header_id|>

Write a short poem<|eot_id|><|start_header_id|>assistant<|end_header_id|>


And the tokenized text: [screenshot not preserved]

If I switch the template to use \n\n\n\n (token 1038), it tokenizes as ['\n\n\n', '\n'] ([1432, 198]): [screenshot not preserved]

(Note: I know there have been efforts to make special tokens render, but as I understand it they currently don't have a textual representation, so you can ignore tokens like 128000, 128006 and 128007 in the sequences above.)

In C#, I work around the issue like this:

var tokensCount = NativeApi.llama_tokenize(model, bytesPtr, bytes.Length, tokensPtr, tokenBuffer.Length, add_bos, special);
var list = new List<LLamaToken>();
// Hack: merge ['\n','\n'] (198, 198) back into ['\n\n'] (271).
// The i + 1 < tokensCount check stops us from reading a stale or out-of-range slot after the last token.
for (int i = 0; i < tokensCount; i++) {
    if (i + 1 < tokensCount && tokenBuffer[i] == 198 && tokenBuffer[i + 1] == 198) { list.Add(271); i++; } // skip the second '\n'
    else { list.Add(tokenBuffer[i]); }
}
return list.ToArray();

(This ignores every other \n merge; it only handles \n\n, which is the one the template commonly uses.)

Lyrcaxis, Apr 21 '24 13:04

I'm also running into this. It seems to degrade performance for llama-3-instruct. (Anecdotally, hackily replacing two newline tokens with the single \n\n token improves performance.)

I'd imagine there are other cases where tokenization is not as greedy as possible, though I'm unsure how this would affect model performance.

MarcusDunn, Apr 22 '24 16:04

bpe_gpt2_preprocess splits the string \n\nword in a bit of a strange way: \n, \nword.

See: https://github.com/ggerganov/llama.cpp/pull/5613
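For comparison, here is a minimal C++ sketch (not llama.cpp's code) that runs an ASCII-only approximation of the GPT-2 pre-tokenizer pattern through std::regex. std::regex has no \p{L}/\p{N}, so [A-Za-z] and [0-9] stand in for them, and gpt2_like_split is a hypothetical name. Even with reference regex semantics, \n\nword comes out as \n, \n, word: the \s+(?!\S) branch backs off one character so the lookahead succeeds, and the remaining \n is matched on its own. Either way the two newlines land in different chunks, and since BPE merges only happen inside a chunk, they can never become the single \n\n token.

```cpp
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// ASCII-only approximation of the GPT-2 pattern
//   's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
// with \p{L} -> [A-Za-z] and \p{N} -> [0-9] (std::regex has no Unicode classes).
std::vector<std::string> gpt2_like_split(const std::string & text) {
    static const std::regex pattern(
        R"('s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+)");
    std::vector<std::string> chunks;
    // Every successive leftmost match becomes one pre-tokenization chunk.
    for (auto it = std::sregex_iterator(text.begin(), text.end(), pattern);
         it != std::sregex_iterator(); ++it) {
        chunks.push_back(it->str());
    }
    return chunks;
}
```

With this pattern, gpt2_like_split("\n\nword") yields {"\n", "\n", "word"}, and gpt2_like_split("Hello world") yields {"Hello", " world"}.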

bullno1, Apr 23 '24 03:04

I am encountering a similar issue. For me, the model likes to generate the token .\n (id=627) at the end of a sentence. However, when I later retokenize the string, I instead get two separate tokens: . (id=13) and \n (id=198).

The same happens with various other tokens, like .\n\n (id=382). Something is really broken with the merging behavior around newlines.

@Lyrcaxis I don't think your hack is sufficient. Because of Llama 3's massive vocabulary, the model picks many newline-related combinations that this bug affects; another one seems to be ---\n (id=11192).

LostRuins, Apr 24 '24 12:04

Does anyone know what regex is used by LLaMA 3 to preprocess the text?

In llama.cpp we currently implement just this:

https://github.com/ggerganov/llama.cpp/blob/aa750c1ede6232c91de890a14a7731d6daa2bc8e/llama.cpp#L12128-L12135

My guess is that we need to update it with whatever LLaMA 3 uses.

ggerganov, Apr 25 '24 11:04

I'm not sure how accurate this is, but here is a possible reference which appears to at least merge \n\n correctly: https://raw.githubusercontent.com/belladoreai/llama3-tokenizer-js/master/llama-tokenizer.js

However, it isn't implemented purely as a regex; it appears to do some additional processing too.

LostRuins, Apr 25 '24 12:04

I see, this is useful. We'll need to support that. Some work to improve BPE preprocessing has started in https://github.com/ggerganov/llama.cpp/pull/6252 and I guess we have to prioritize it, since this likely leads to poor generation quality.

ggerganov, Apr 25 '24 12:04

Quoting @ggerganov: "Does anyone know what regex is used by LLaMA 3 to preprocess the text?"

Is this what you'd be looking for?

https://github.com/meta-llama/llama3/blob/af6eedf7042fb51d00b2b26d8ef1ceaab73e1670/llama/tokenizer.py#L47
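That link appears to hold the same pattern quoted in the C++ implementation later in this thread. As a rough illustration (not llama.cpp code), here is an ASCII-only approximation run through std::regex; \p{L}/\p{N} become [A-Za-z]/[0-9] and the case-insensitive (?i:...) group is spelled out with character classes, since std::regex supports neither. llama3_like_split is a hypothetical name. The \s*[\r\n]+ branch keeps a run of newlines in one chunk, and the [\r\n]* suffix lets punctuation absorb the newlines that follow it, which would explain why tokens like .\n (627) and ---\n (11192) exist in the vocabulary.

```cpp
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// ASCII-only approximation of the Llama 3 pre-tokenizer pattern. Assumptions:
// \p{L} -> [A-Za-z], \p{N} -> [0-9], and (?i:'s|'t|...) expanded by hand,
// because std::regex has neither Unicode properties nor inline (?i:).
std::vector<std::string> llama3_like_split(const std::string & text) {
    static const std::regex pattern(
        R"('[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])"  // contractions
        R"(|[^\r\nA-Za-z0-9]?[A-Za-z]+)"                           // optional prefix + letters
        R"(|[0-9]{1,3})"                                           // at most three digits
        R"(| ?[^\sA-Za-z0-9]+[\r\n]*)"                             // punctuation + trailing newlines
        R"(|\s*[\r\n]+)"                                           // whitespace ending in newlines
        R"(|\s+(?!\S)|\s+)");                                      // any other whitespace
    std::vector<std::string> chunks;
    for (auto it = std::sregex_iterator(text.begin(), text.end(), pattern);
         it != std::sregex_iterator(); ++it) {
        chunks.push_back(it->str());
    }
    return chunks;
}
```

With this approximation, "\n\nword" splits into {"\n\n", "word"} and "end.\nNext" into {"end", ".\n", "Next"}, so both merged tokens reported above become reachable again.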

MarcusDunn, Apr 25 '24 19:04

I have an implementation of the Llama 3 regex.

I tested it by generating texts (randomly concatenating strings from tokenizer.json) and comparing the resulting encodings against tiktoken's.
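A minimal sketch of that kind of randomized differential test, assuming two pre-tokenizers with the same signature. find_mismatch and Splitter are hypothetical names; in the test described here the pieces come from tokenizer.json and the reference is tiktoken, neither of which is wired in below. It needs C++17 for std::optional.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <random>
#include <string>
#include <vector>

// A pre-tokenizer: text in, list of chunks out.
using Splitter = std::function<std::vector<std::string>(const std::string &)>;

// Build each input by concatenating `pieces_per_text` randomly chosen pieces,
// run both splitters, and return the first input on which they disagree.
std::optional<std::string> find_mismatch(const std::vector<std::string> & pieces,
                                         const Splitter & candidate,
                                         const Splitter & reference,
                                         int num_texts, int pieces_per_text,
                                         unsigned seed = 42) {
    assert(!pieces.empty());
    std::mt19937 rng(seed);  // fixed seed so a failing input can be reproduced
    std::uniform_int_distribution<std::size_t> pick(0, pieces.size() - 1);
    for (int t = 0; t < num_texts; t++) {
        std::string text;
        for (int p = 0; p < pieces_per_text; p++) {
            text += pieces[pick(rng)];
        }
        if (candidate(text) != reference(text)) {
            return text;  // worth keeping as a regression test case
        }
    }
    return std::nullopt;
}
```

Mixing newline-heavy pieces (\n, \n\n, .\n) into the piece list is what makes this kind of test likely to hit the cases discussed in this issue.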

The main idea is to first record the length (in codepoints) of every match in tokens_length, then build bpe_encoded_words from those lengths.

If this is useful, I can do a PR.

    std::vector<std::string> bpe_llama3_preprocess(const std::string & text) {
        // LLAMA3 Regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
        const auto cpts = unicode_cpts_from_utf8(text);
        const int num_cpts = (int)cpts.size();

        auto _tolower = [] (const int cpt) -> int {
            if ('A' <= cpt && cpt <= 'Z')
                return cpt + ('a'-'A');
            return cpt;
        };

        auto _get_cpt = [&] (const int pos) -> uint32_t {
            return (0 <= pos && pos < num_cpts) ? cpts[pos] : 0;
        };

        auto _get_cpt_type = [&] (const int pos) -> int {
            return (0 <= pos && pos < num_cpts) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
        };

        std::vector<int> tokens_length;
        tokens_length.reserve(cpts.size()/3+4);
        int _prev_end = 0;
        auto _add_token = [&] (const int end) -> int {
            GGML_ASSERT(_prev_end <= end && end <= num_cpts);
            int len = end - _prev_end;
            if(len > 0)
                tokens_length.push_back(len);
            _prev_end = end;
            //if( len && true ) {
            //    std::string s = "";
            //    for( int p = end-len; p < end; p++ )
            //        s += unicode_cpt_to_utf8(cpts[p]);
            //    printf( ">>> '%s'\n", s.c_str() );
            //}
            return len;
        };

        int pos = 0;
        while (pos < num_cpts) {

            const uint32_t cpt = _get_cpt(pos);
            const int cpt_type = _get_cpt_type(pos);

            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
            if (cpt == '\'' && pos+1 < num_cpts) {
                uint32_t cpt_next = _tolower(_get_cpt(pos+1));
                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                    pos += _add_token(pos+2);
                    continue;
                } else if (pos+2 < num_cpts) {
                    uint32_t cpt_next_next = _tolower(_get_cpt(pos+2));
                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                        (cpt_next == 'v' && cpt_next_next == 'e') ||
                        (cpt_next == 'l' && cpt_next_next == 'l')) {
                        pos += _add_token(pos+3);
                        continue;
                    }
                }
            }

            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  (the prefix can't be a letter, but a leading letter still matches with an empty prefix, so cpt_type may be LETTER here)
            if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_DIGIT) {
                if(cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) {  // one or more letters
                    pos++;
                    while(_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER)
                        pos++;
                    _add_token(pos);
                    continue;
                }
            }

            // regex: \p{N}{1,3}
            if (cpt_type == CODEPOINT_TYPE_DIGIT) {
                int ini = pos;
                while(_get_cpt_type(pos) == CODEPOINT_TYPE_DIGIT) {
                    if (++pos - ini >= 3 ) {
                        _add_token(pos);
                        ini = pos;
                    }
                }
                _add_token(pos);
                continue;
            }

            // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
            uint32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
            int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
            if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
                pos += (cpt == ' ');
                while(cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED)
                    cpt2_type = _get_cpt_type(++pos);
                cpt2 = _get_cpt(pos);
                while(cpt2 == '\r' || cpt2 == '\n')
                    cpt2 = _get_cpt(++pos);
                _add_token(pos);
                continue;
            }

            int num_whitespaces = 0;
            int last_pos_r_or_n = -1;
            while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) {
                cpt2 = _get_cpt(pos+num_whitespaces);
                if (cpt2 == '\r' || cpt2 == '\n')
                    last_pos_r_or_n = pos+num_whitespaces;
                num_whitespaces++;
            }

            // regex: \s*[\r\n]+
            if (last_pos_r_or_n >= 0) {
                pos = last_pos_r_or_n + 1;
                _add_token(pos);
                continue;
            }

            // regex: \s+(?!\S)
            if(num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
                pos += num_whitespaces - 1;
                _add_token(pos);
                continue;
            }

            // regex: \s+
            if(num_whitespaces > 0) {
                pos += num_whitespaces;
                _add_token(pos);
                continue;
            }

            // no matches
            _add_token(++pos);
        }

        GGML_ASSERT(pos == num_cpts);
        _add_token(pos);

        pos = 0;
        std::vector<std::string> bpe_encoded_words(tokens_length.size());
        for (int n = 0; n < (int)tokens_length.size(); n++) {
            std::string &encoded_token = bpe_encoded_words[n];
            const int length = tokens_length[n];
            GGML_ASSERT(length > 0);
            for (int i = 0; i < length; i++) {
                std::string char_utf8 = unicode_cpt_to_utf8(cpts[pos++]);
                for (char c : char_utf8) {
                    encoded_token += unicode_byte_to_utf8(c);
                }
            }
        }

        GGML_ASSERT(pos == num_cpts);
        return bpe_encoded_words;
    }
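For context on the final loop: as far as I can tell, unicode_byte_to_utf8 applies the GPT-2 byte-level table, which maps every UTF-8 byte to a printable codepoint before BPE runs. The sketch below rebuilds that table from the GPT-2 convention (an assumption, not copied from llama.cpp; build_byte_to_codepoint and to_byte_level are hypothetical names): bytes in three printable Latin-1 ranges keep their value, and every other byte is assigned 256, 257, ... in increasing byte order. That is why a space shows up as Ġ (U+0120) and a newline as Ċ (U+010A) in byte-level vocabularies.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string>

// GPT-2 style byte -> codepoint table: bytes in '!'..'~', 0xA1..0xAC and
// 0xAE..0xFF map to themselves; the remaining bytes get 256, 257, ... in order.
std::array<uint32_t, 256> build_byte_to_codepoint() {
    std::array<uint32_t, 256> table{};
    uint32_t next = 256;
    for (int b = 0; b < 256; b++) {
        const bool printable = (b >= '!' && b <= '~') ||
                               (b >= 0xA1 && b <= 0xAC) ||
                               (b >= 0xAE && b <= 0xFF);
        table[b] = printable ? static_cast<uint32_t>(b) : next++;
    }
    return table;
}

// Map every byte of `text` through the table and UTF-8 encode the result.
// All table values are below 0x800, so one or two UTF-8 bytes suffice.
std::string to_byte_level(const std::string & text) {
    static const auto table = build_byte_to_codepoint();
    std::string out;
    for (unsigned char byte : text) {
        const uint32_t cp = table[byte];
        if (cp < 0x80) {
            out += static_cast<char>(cp);
        } else {
            out += static_cast<char>(0xC0 | (cp >> 6));
            out += static_cast<char>(0x80 | (cp & 0x3F));
        }
    }
    return out;
}
```

Under this table, to_byte_level("\n\n") produces "ĊĊ", which should be how the \n\n token is spelled in a byte-level tokenizer.json.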

jaime-m-p, Apr 27 '24 20:04

The issue should be fixed by #6920.

ggerganov, Apr 29 '24 14:04

Awesome! Thanks.

Lyrcaxis, Apr 29 '24 22:04