Space after unnormalized token is added when `use_fast=True` for Llama tokenizer

Open • Butanium opened this issue 1 year ago • 10 comments

Related to: https://github.com/huggingface/transformers/issues/25073

In my current project, I'd like to add a special token that doesn't insert a space before the next token. Currently, I need to specify `use_fast=False` for this to work. However:

  • It is unclear to me why I should expect the slow and fast tokenizers to behave differently.
  • This finding doesn't generalize to, e.g., the Gemma tokenizer, which never adds such a space.
  • Maybe this should be an explicit option in the tokenizer_kwargs?

```python
from transformers import AutoTokenizer

fast_tokenizer = AutoTokenizer.from_pretrained("meta-llama/llama-2-7b-hf", use_fast=True)
slow_tokenizer = AutoTokenizer.from_pretrained("meta-llama/llama-2-7b-hf", use_fast=False)
tok = fast_tokenizer.bos_token  # "<s>"
s = f'a:{tok}->'
print(f"fast: {fast_tokenizer.tokenize(s)}\nslow: {slow_tokenizer.tokenize(s)}")
# fast: ['▁a', ':', '<s>', '▁->']
# slow: ['▁a', ':', '<s>', '->']
```

Butanium, Aug 14 '24 11:08

Oh wait, @ArthurZucker, is that what you're fixing in https://github.com/huggingface/tokenizers/pull/1568 ?

Butanium, Aug 14 '24 11:08

Same issue with unnormalized non-special tokens:

```python
from tokenizers import AddedToken
from transformers import AutoTokenizer

tok_name = "meta-llama/llama-2-7b-hf"
fast_tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=True)
slow_tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False)
tok = "<special>"
# Unnormalized, non-special added token
t = AddedToken(tok, normalized=False, special=False)
fast_tokenizer.add_tokens([t])
slow_tokenizer.add_tokens([t])
s = f'hello:{tok}->'
print(f"fast: {fast_tokenizer.tokenize(s)}\nslow: {slow_tokenizer.tokenize(s)}")
# fast: ['▁hello', ':', '<special>', '▁->']
# slow: ['▁hello', ':', '<special>', '->']
```

Butanium, Aug 14 '24 12:08

And there are even more differences when you set normalized=True for special tokens:

```python
from tokenizers import AddedToken
from transformers import AutoTokenizer

tok_name = "meta-llama/llama-2-7b-hf"
fast_tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=True)
slow_tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False)
tok = "<special>"
# Normalized, special added token
t = AddedToken(tok, normalized=True, special=True)
fast_tokenizer.add_tokens([t], special_tokens=True)
slow_tokenizer.add_tokens([t], special_tokens=True)
s = f'hello:{tok}->'
print(f"fast: {fast_tokenizer.tokenize(s)}\nslow: {slow_tokenizer.tokenize(s)}")
# fast: ['▁hello', ':', '<', 'special', '>', '->']
# slow: ['▁hello', ':', '<special>', '->']
```

Butanium, Aug 14 '24 12:08

Also, if you pass the add_prefix_space argument, the fast tokenizer is actually rebuilt from the slow implementation (from_slow=True), which leads to different behavior for the code above! https://github.com/huggingface/transformers/blob/9485289f374d4df7e8aa0ca917dc131dcf64ebaf/src/transformers/models/llama/tokenization_llama_fast.py#L154

Butanium, Aug 14 '24 13:08

No, this was fixed a LONG time ago!

```python
from tokenizers import AddedToken
from transformers import AutoTokenizer

tok_name = "meta-llama/llama-2-7b-hf"
# from_slow=True forces the conversion from the slow tokenizer, so legacy=False is taken into account
fast_tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=True, legacy=False, from_slow=True)
slow_tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False)
tok = "<special>"
t = AddedToken(tok, normalized=True, special=True)
fast_tokenizer.add_tokens([t], special_tokens=True)
slow_tokenizer.add_tokens([t], special_tokens=True)
s = f'hello:{tok}->'
print(f"fast: {fast_tokenizer.tokenize(s)}\nslow: {slow_tokenizer.tokenize(s)}")
# fast: ['▁hello', ':', '<', 'special', '>', '->']
# slow: ['▁hello', ':', '<special>', '->']
```

ArthurZucker, Aug 15 '24 17:08

See #1357

ArthurZucker, Aug 15 '24 17:08

Hey @ArthurZucker, thanks for your answer. I'm using tokenizers 0.19.1, which should include the fix. I'm really confused right now: why isn't the fact that use_fast alters the behavior of the tokenizer considered an issue? My more practical question is: is there a way to add a token such that:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model, ...)  # magic kwargs
# magically add <special>
s = 'a:<special>->'
print(tokenizer.tokenize(s))
```

will always print `[{whatever}, '<special>', '->']`, the key point being that the last token is `->` and not `▁->`.
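A minimal sketch of this requirement as a plain-Python check that could back a unit test. The helper name `no_space_after` is hypothetical, and it only inspects the token list returned by `tokenize()`; it does not load or configure any tokenizer.

```python
SPACE = "\u2581"  # "▁", SentencePiece's word-boundary marker


def no_space_after(tokens, added_token):
    """Return True if the token right after `added_token` has no leading "▁".

    Returns False when `added_token` is absent (e.g. it was split into
    several pieces) or is the last token.
    """
    try:
        i = tokens.index(added_token)
    except ValueError:
        return False
    if i + 1 >= len(tokens):
        return False
    return not tokens[i + 1].startswith(SPACE)


# Token lists taken from the outputs reported earlier in this issue
print(no_space_after(["▁a", ":", "<s>", "▁->"], "<s>"))  # False (fast)
print(no_space_after(["▁a", ":", "<s>", "->"], "<s>"))   # True (slow)
```

With a loaded tokenizer, the call would look like `no_space_after(tokenizer.tokenize('a:<special>->'), '<special>')`. Note that the normalized=True case, where the fast tokenizer splits `<special>` into `'<'`, `'special'`, `'>'`, also fails this check.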

Butanium, Aug 26 '24 10:08

Yes, what affects this is the legacy flag, as Llama was added before we fixed the issue.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model, legacy=False)  # magic kwargs
# magically add <special>
s = 'a:<special>->'
print(tokenizer.tokenize(s))
```

Note that setting legacy=False does not always trigger the conversion from the slow tokenizer, and it is that conversion which forces the legacy attribute to actually be taken into account!

ArthurZucker, Aug 26 '24 13:08

OK, so I should write some unit tests and choose different kwargs depending on the tokenizer to get the same behavior?

Butanium, Aug 29 '24 23:08

No, sorry. Basically, you can just check the tokenizer's pre_tokenizer: if it's Metaspace, its prepend_scheme should be set to "first" instead of "always".
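To illustrate why the prepend scheme matters, here is a simplified pure-Python model (not the tokenizers implementation) of a Metaspace-style step applied after the input has been split around added tokens. Each remaining segment has its spaces replaced with "▁"; a leading "▁" is then prepended to every segment with "always", only to the segment that starts at the beginning of the input with "first", and to none with "never". The helper names `split_on_added` and `pre_tokenize` are hypothetical, and the subsequent BPE step (which would further split, e.g., `▁a:` into `▁a` and `:`) is omitted. On a real fast tokenizer from transformers, the component to inspect is `tokenizer.backend_tokenizer.pre_tokenizer`.

```python
SPACE = "\u2581"  # "▁", the Metaspace replacement character


def split_on_added(text, added_tokens):
    """Split `text` into (piece, is_added) pairs around exact added-token matches."""
    pieces, buf, i = [], "", 0
    while i < len(text):
        match = next((t for t in added_tokens if text.startswith(t, i)), None)
        if match is None:
            buf += text[i]
            i += 1
            continue
        if buf:
            pieces.append((buf, False))
            buf = ""
        pieces.append((match, True))
        i += len(match)
    if buf:
        pieces.append((buf, False))
    return pieces


def pre_tokenize(text, added_tokens, prepend_scheme="always"):
    """Toy Metaspace-like pre-tokenization; added tokens are left untouched."""
    if prepend_scheme not in ("always", "first", "never"):
        raise ValueError(f"unknown prepend_scheme: {prepend_scheme!r}")
    out, at_start = [], True
    for piece, is_added in split_on_added(text, added_tokens):
        if not is_added:
            piece = piece.replace(" ", SPACE)
            prepend = prepend_scheme == "always" or (
                prepend_scheme == "first" and at_start
            )
            if prepend and not piece.startswith(SPACE):
                piece = SPACE + piece
        out.append(piece)
        at_start = False  # later pieces no longer start at offset 0
    return out


print(pre_tokenize("a:<s>->", ["<s>"], "always"))  # ['▁a:', '<s>', '▁->']
print(pre_tokenize("a:<s>->", ["<s>"], "first"))   # ['▁a:', '<s>', '->']
```

Under this model, "always" reproduces the extra `▁->` seen with the fast tokenizer in this issue, while "first" gives the slow tokenizer's `->`.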

ArthurZucker, Aug 30 '24 15:08