[BUG] Fast tokenizer does not handle AddedTokens properly (no problem in the Transformers Python tokenizer implementation)
When I try to add new tokens to the vocabulary, there are three issues with the Fast tokenizer; the Python (slow) tokenizer handles the same cases correctly.
- The space before an added token is deleted if the added token is not a special token.
- If the added token is not a special token, placing added tokens back to back (without spaces) prevents the following token from being tokenized as an added token.
- A single space (ID 28705) between two added tokens is treated as two spaces (ID 259) when the added tokens are special tokens.
Source code to reproduce the issue
from transformers import (
    AutoProcessor,
    LlamaTokenizer,
    LlamaTokenizerFast,
)

processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=True,
)
print(processor.tokenizer)


def test_tokenizer(tokenizer):
    # print(f"Tokenizer: {tokenizer}")
    print("=======")
    test_texts = [
        "!@#",
        "!@# ",
        "!@# <ACTION_1>",
        "!@# <ACTION_1> ",
        "!@# <ACTION_1> <ACTION_2>",
        "!@# <ACTION_1><ACTION_2>",
    ]
    for text in test_texts:
        print(f"{text:30}", tokenizer(text))


# 1) Slow tokenizer, non-special added tokens
tokenizer = LlamaTokenizer.from_pretrained("HuggingFaceM4/idefics2-8b")
tokenizer.add_tokens([f"<ACTION_{idx}>" for idx in range(18)])
test_tokenizer(tokenizer)

# 2) Slow tokenizer, special added tokens
tokenizer = LlamaTokenizer.from_pretrained("HuggingFaceM4/idefics2-8b")
tokenizer.add_tokens([f"<ACTION_{idx}>" for idx in range(18)], special_tokens=True)
test_tokenizer(tokenizer)

# 3) Fast tokenizer, non-special added tokens
tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/idefics2-8b")
tokenizer.add_tokens([f"<ACTION_{idx}>" for idx in range(18)])
test_tokenizer(tokenizer)

# 4) Fast tokenizer, special added tokens
tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/idefics2-8b")
tokenizer.add_tokens([f"<ACTION_{idx}>" for idx in range(18)], special_tokens=True)
test_tokenizer(tokenizer)
Execution result
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
LlamaTokenizerFast(name_or_path='HuggingFaceM4/idefics2-8b', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<unk>', 'additional_special_tokens': ['<fake_token_around_image>', '<image>', '<end_of_utterance>']}, clean_up_tokenization_spaces=False), added_tokens_decoder={
0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
32000: AddedToken("<fake_token_around_image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
32001: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
32002: AddedToken("<end_of_utterance>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
=======
!@# {'input_ids': [1, 918, 28818, 28771], 'attention_mask': [1, 1, 1, 1]}
!@# {'input_ids': [1, 918, 28818, 28771, 28705], 'attention_mask': [1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 28705, 32004], 'attention_mask': [1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 28705], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> <ACTION_2> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 28705, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
!@# <ACTION_1><ACTION_2> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
=======
!@# {'input_ids': [1, 918, 28818, 28771], 'attention_mask': [1, 1, 1, 1]}
!@# {'input_ids': [1, 918, 28818, 28771, 28705], 'attention_mask': [1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 28705, 32004], 'attention_mask': [1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 28705], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> <ACTION_2> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 28705, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
!@# <ACTION_1><ACTION_2> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
=======
!@# {'input_ids': [1, 918, 28818, 28771], 'attention_mask': [1, 1, 1, 1]}
!@# {'input_ids': [1, 918, 28818, 28771, 28705], 'attention_mask': [1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 32004], 'attention_mask': [1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 32004, 28705], 'attention_mask': [1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> <ACTION_2> {'input_ids': [1, 918, 28818, 28771, 32004, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1]}
!@# <ACTION_1><ACTION_2> {'input_ids': [1, 918, 28818, 28771, 32004, 28789, 17615, 28730, 28750, 28767], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
=======
!@# {'input_ids': [1, 918, 28818, 28771], 'attention_mask': [1, 1, 1, 1]}
!@# {'input_ids': [1, 918, 28818, 28771, 28705], 'attention_mask': [1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 28705, 32004], 'attention_mask': [1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 259], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
!@# <ACTION_1> <ACTION_2> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 259, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
!@# <ACTION_1><ACTION_2> {'input_ids': [1, 918, 28818, 28771, 28705, 32004, 32005], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
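To make the third issue easier to see, the two whitespace IDs from the last run can be converted back to tokens. This is a minimal sketch (same checkpoint as above); I expect 28705 to map to a single '▁' and 259 to '▁▁', i.e. the single space between two special added tokens is being encoded as a double-space token, but I have not re-verified these exact strings here:

from transformers import LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/idefics2-8b")
# Inspect the two whitespace tokens that appear around added special tokens.
print(tokenizer.convert_ids_to_tokens([28705, 259]))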
Additional Note
If I use the from_slow option to load the Fast tokenizer, there is no problem:
tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/idefics2-8b", from_slow=True)
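Another possible workaround is to add the tokens as AddedToken objects with normalized=False and explicit lstrip/rstrip, so the fast tokenizer's normalizer does not rewrite the text around them. This is only a sketch based on the AddedToken options visible in added_tokens_decoder above; I have not verified that it avoids all three issues:

from transformers import AddedToken, LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/idefics2-8b")
# normalized=False is an assumption about what might help here, not a verified fix.
tokenizer.add_tokens(
    [AddedToken(f"<ACTION_{idx}>", lstrip=False, rstrip=False, normalized=False) for idx in range(18)]
)
test_tokenizer(tokenizer)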