
llama.cpp BPE tokenization of wiki.test does not match the HF tokenization

Open ggerganov opened this issue 8 months ago • 9 comments

I did the following test to tokenize wiki.test.raw using our tokenizer and the Python tokenizer. The expectation is that the outputs will match:

# generate ggml-vocab-falcon.gguf
./convert-falcon-hf-to-gguf.py --vocab-only ~/development/huggingface/falcon-7b/ --outfile ./models/ggml-vocab-falcon.gguf

# tokenize using Python
python3 tests/test-tokenizer-0-falcon.py ~/development/huggingface/falcon-7b/ --fname-tok ./build/wikitext-2-raw/wiki.test.raw

# tokenize using llama.cpp
cd build
make -j
./bin/test-tokenizer-0-falcon ../models/ggml-vocab-falcon.gguf ./wikitext-2-raw/wiki.test.raw

# compare the results
cmp ./wikitext-2-raw/wiki.test.raw.tok ./wikitext-2-raw/wiki.test.raw.tokcpp 
./wikitext-2-raw/wiki.test.raw.tok ./wikitext-2-raw/wiki.test.raw.tokcpp differ: char 1, line 1

The results are pretty close, but not exactly the same. Any ideas why the test does not pass? I thought that #3252 would resolve this.

cc @goerch

ggerganov avatar Oct 06 '23 13:10 ggerganov

That is a nice test. I made some modifications to get more detailed output from the tests and see differences like:

  1. Problem with endoftext

  (screenshot)

  2. Non-greediness

  (screenshot)

  (screenshot)

goerch avatar Oct 06 '23 14:10 goerch

Intermediate results of debugging: bpe_gpt2_preprocess seems to do the right thing, while llm_tokenizer_bpe::tokenize seems to be subtly broken, although it looks very similar to examples/gptneox-wip. Paging @cmp-nct for help, because git blame doesn't work very well here.

goerch avatar Oct 06 '23 19:10 goerch

llm_tokenizer_bpe::tokenize seems to be subtly broken

I implemented an independent port of the gpt2-tokenizer (will share the code if someone is interested) and it shows the same behavior as the llama.cpp tokenizer. I also tried to use the slow tokenizer of HF (i.e. the Python implementation) for comparison, but without success, i.e. I didn't get it working (any tips appreciated!). I tried to understand the fast tokenizer implemented here, which looks like it supports some advanced features. This seems hopeless unless we are willing to debug the complete HF stack (which I currently don't know how to do). BTW: I'm not talking about the endoftext problem here.

goerch avatar Oct 07 '23 20:10 goerch

llm_tokenizer_bpe::tokenize seems to be subtly broken

I implemented an independent port of the gpt2-tokenizer ... I also tried to use the slow tokenizer of HF (i.e. the Python implementation) for comparison, but without success (any tips appreciated!).

You can create the slow (or fast) GPT2 tokenizer in tests/test-tokenizer-0-falcon.py like so:

from transformers import GPT2TokenizerFast, GPT2Tokenizer
slow_tokenizer = GPT2Tokenizer(vocab_file=dir_tokenizer + '/vocab.json', merges_file=dir_tokenizer + '/merges.txt')
fast_tokenizer = GPT2TokenizerFast(tokenizer_file=dir_tokenizer + '/tokenizer.json')

You will have to create the files vocab.json and merges.txt yourself. The file vocab.json should contain only the vocab map from Falcon's tokenizer.json (e.g. see https://huggingface.co/gpt2/blob/main/vocab.json). The file merges.txt should contain only the contents of the merges array, one array element per line (i.e. space separated token pairs, e.g. see https://huggingface.co/gpt2/blob/main/merges.txt).
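For anyone reproducing this, here is a minimal sketch of how those two files could be extracted from Falcon's tokenizer.json, assuming the usual BPE layout with "model" -> "vocab" and "model" -> "merges" (the path is illustrative, and merges may be stored either as "a b" strings or as pairs depending on the tokenizers version):

import json, os

dir_tokenizer = os.path.expanduser("~/development/huggingface/falcon-7b")  # adjust to your path

with open(os.path.join(dir_tokenizer, "tokenizer.json"), encoding="utf-8") as f:
    tok = json.load(f)

# vocab.json: just the BPE vocab map
with open(os.path.join(dir_tokenizer, "vocab.json"), "w", encoding="utf-8") as f:
    json.dump(tok["model"]["vocab"], f, ensure_ascii=False)

# merges.txt: one merge per line; handle both "a b" strings and ["a", "b"] pairs
merges = [m if isinstance(m, str) else " ".join(m) for m in tok["model"]["merges"]]
with open(os.path.join(dir_tokenizer, "merges.txt"), "w", encoding="utf-8") as f:
    f.write("\n".join(merges) + "\n")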

You will notice that the slow tokenizer tokenizes "2000" differently ("20" "00") than the fast one ("200" "0"). So yes, I think we are running into a HF implementation bug, but the cpp code tokenizes like the (presumably now less popular) slow tokenizer.
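With the two tokenizers constructed as above, the discrepancy can be reproduced directly (expected output per the observation above; not verified against every transformers version):

print(slow_tokenizer.tokenize("2000"))  # -> ['20', '00'], matching the cpp tokenizer
print(fast_tokenizer.tokenize("2000"))  # -> ['200', '0']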

And the <|endoftext|> in the front is trivial, it's just the artificially injected BOS token (which I believe is a Llama thing and should not be inserted for Falcon).
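As a sanity check on that last point, the HF tokenizer for Falcon can be asked directly whether it inserts any leading special token (illustrative snippet, model id assumed to be tiiuae/falcon-7b):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
ids = tok("Hello world").input_ids
print(ids)                              # no leading <|endoftext|> expected here
print(tok.convert_ids_to_tokens(ids))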

jploski avatar Oct 09 '23 20:10 jploski

So maybe it is best to switch to the slow tokenizer in test-tokenizer-0-falcon.py and close the issue if things match? Probably also add this as a ctest.

ggerganov avatar Oct 10 '23 12:10 ggerganov

I could imagine this to be a hairy problem, because I'd assume a couple of models have been trained with the fast tokenizers?

goerch avatar Oct 10 '23 12:10 goerch

I could imagine this to be a hairy problem, because I'd assume a couple of models have been trained with the fast tokenizers?

Yes, I suppose everyone uses the fast ones because they are the default, so having a tokenizer in llama.cpp which behaves differently is not good.

One point which I am still unclear about is whether the fast tokenizer, which for some reason (also) wants tokenizer.json rather than just the vocab.json/merges.txt files as input, maybe relies on some extra information from tokenizer.json which makes it behave differently in our test case. So there's still some chance it might not be a bug in the HF implementation after all, but rather our lack of understanding of it. I'm hoping to learn more from HF's response to https://github.com/huggingface/tokenizers/issues/1363.

jploski avatar Oct 10 '23 12:10 jploski

The discrepancy here is because Falcon's tokenizer.json specifies a different pre_tokenizer.

Most BPE-using models use a pre_tokenizer config that mimics GPT2, i.e. "pretokenizing" is done with just the standard ByteLevel regex:

  "pre_tokenizer": {
    "type": "ByteLevel",
    "add_prefix_space": false,
    "trim_offsets": true,
    "use_regex": true
  },

Replacing Falcon's pre_tokenizer with this gets behavior consistent with "slow", but that's not really what we want. Falcon instead specifies this:

"pre_tokenizer": {
  "type": "Sequence",
  "pretokenizers": [
    {
      "type": "Punctuation",
      "behavior": "Contiguous"
    },
    {
      "type": "ByteLevel",
      "add_prefix_space": false,
      "trim_offsets": true,
      "use_regex": true
    },
    {
      "type": "Digits",
      "individual_digits": false
    },
    {
      "type": "Split",
      "pattern": {
        "Regex": "[0-9][0-9][0-9]"
      },
      "behavior": "Isolated",
      "invert": false
    }
  ]
}

That is, it first applies punctuation splitting, then the standard ByteLevel regex, then "Digits" (forcing spans of digits to be separated from non-digits), and finally an "Isolated"-mode custom Split on a regex matching 3 consecutive digits, so that no token over 3 digits long makes it past the pretokenizer.

It's the last one that seems to cause the discrepancy here, but the problem is that to be fully consistent with tokenizers we will need to implement some of the less common pre_tokenizer options.
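To poke at the pre-tokenization step in isolation (without running the full tokenizer), the sequence above can be rebuilt with the tokenizers Python package. A rough sketch, not part of llama.cpp:

from tokenizers import Regex, pre_tokenizers

falcon_pre = pre_tokenizers.Sequence([
    pre_tokenizers.Punctuation(behavior="contiguous"),
    pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True),
    pre_tokenizers.Digits(individual_digits=False),
    pre_tokenizers.Split(pattern=Regex("[0-9][0-9][0-9]"), behavior="isolated", invert=False),
])

# digit runs longer than 3 characters get chopped by the final Split,
# which is where the "2000" -> "200" + "0" behavior comes from
print(falcon_pre.pre_tokenize_str("In 2000 there were 12345 cases"))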

apage43 avatar Oct 10 '23 16:10 apage43

@apage43 cc @ggerganov

Could you take a look at my code? I followed the procedure you outlined and even checked the source code, but I'm still getting inconsistent results. Is there a way to directly test the pre-tokenizer without comparing the final output of the tokenizer? This might help us pinpoint the exact issue. https://github.com/ggerganov/llama.cpp/pull/5613

I'm really confident that my GPT-2 style pre-tokenizer works perfectly. I carefully followed the regex pattern and tested it extensively, using more than 10GB of data that included both synthetic and real examples.

Edit: Ah, I understand now! The = symbol isn't classified under Unicode punctuation!
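For reference, a small illustrative check: '=' falls into the Sm (math symbol) Unicode category rather than any P* punctuation category, yet HF's Punctuation pre-tokenizer (which, as far as I can tell, also treats ASCII punctuation characters as punctuation) still splits around it:

import unicodedata
from tokenizers import pre_tokenizers

print(unicodedata.category("="))  # 'Sm' -> symbol, not a P* (punctuation) category
print(pre_tokenizers.Punctuation(behavior="contiguous").pre_tokenize_str("a=b"))
# -> splits into 'a', '=', 'b'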

bobqianic avatar Feb 20 '24 17:02 bobqianic

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions[bot] avatar Apr 06 '24 01:04 github-actions[bot]