
Inconsistent behaviour between LlamaTokenizer and LlamaTokenizerFast when lstrip=True

Open x54-729 opened this issue 1 year ago • 2 comments

System Info

ubuntu transformers==4.39.0dev

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Refer to tokenization test:

https://github.com/huggingface/transformers/blob/9acce7de1cb8229304a467938ebb47727d60cdb2/tests/test_tokenization_common.py#L864-L881

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama2", use_fast=False)

special_token = tokenizer.all_special_tokens[0]

text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

toks_before_adding = tokenizer.tokenize(text)  # toks before adding new_toks

new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]

added = tokenizer.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

toks_after_adding = tokenizer.tokenize(text)
toks_after_adding2 = tokenizer.tokenize(text2)

print(toks_after_adding, toks_after_adding2)

If use_fast=False, the output is:

['<s>', 'aaaaa bbbbbb', '▁low', 'cccccccccdddddddd', '▁l', '▁', '<s>'] ['<s>', 'AAAAA BBBBBB', '▁low', 'CCCCCCCCCDDDDDDDD', '▁l', '▁', '<s>']

Otherwise (use_fast=True):

['<s>', '▁', '▁aaaaa▁bbbbbb', '▁low', '▁cccccccccdddddddd', '▁l', '▁', '<s>'] ['<s>', '▁', '▁AAAAA▁BBBBBB', '▁low', '▁CCCCCCCCCDDDDDDDD', '▁l', '▁', '<s>']

Expected behavior

It seems that lstrip=True did not take effect. With use_fast=True, the output should be the same as with use_fast=False.
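
For reference, a minimal side-by-side sketch (the checkpoint path is a placeholder, as above) that runs both tokenizers under the same added-token settings, so the divergence shows up in a single run:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

path = "/path/to/llama2"  # placeholder local path to a Llama-2 checkpoint
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]

for use_fast in (False, True):
    tok = AutoTokenizer.from_pretrained(path, use_fast=use_fast)
    special = tok.all_special_tokens[0]
    text = special + " aaaaa bbbbbb low cccccccccdddddddd l " + special
    tok.add_tokens([AddedToken(t, lstrip=True, rstrip=True) for t in new_toks])
    # If lstrip/rstrip behaved consistently, both branches would print the same tokens.
    print(f"use_fast={use_fast}:", tok.tokenize(text))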

x54-729 · Mar 13 '24 09:03

cc @ArthurZucker

amyeroberts · Mar 13 '24 10:03

There is also this inconsistency issue.

Here is my case:

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

model_name_or_path = "mistralai/Mistral-7B-v0.1"
fast_tk = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
    padding_side="left",
)
slow_tk = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=False,
    padding_side="left",
)
new_toks = ["<|im_start|>", "<|im_end|>"]
# Add the tokens to both tokenizers so the comparison is fair.
added_fast = fast_tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])
added_slow = slow_tk.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

test_text = "<|im_start|>Who are you?<|im_end|>"
print(fast_tk(test_text))
print(slow_tk(test_text))

For the fast tokenizer (use_fast=True), the output is:

{'input_ids': [32006, 11447, 460, 368, 28804, 28789, 28766, 321, 28730, 416, 28766, 28767], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Otherwise (use_fast=False):

{'input_ids': [32006, 6526, 460, 368, 28804, 32007], 'attention_mask': [1, 1, 1, 1, 1, 1]}

The second result, {'input_ids': [32006, 6526, 460, 368, 28804, 32007], 'attention_mask': [1, 1, 1, 1, 1, 1]}, is the expected one.

Trangle · Mar 22 '24 06:03

You need to add the tokens with normalized=False. This is not a bug; it has been filed as an issue many times 😉
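
A minimal sketch of that workaround, reusing the Mistral setup from the comment above (only the normalized=False flag is new):

from transformers import AutoTokenizer
from transformers.tokenization_utils import AddedToken

fast_tk = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
new_toks = ["<|im_start|>", "<|im_end|>"]
# normalized=False keeps the fast tokenizer from running its normalizer over the
# added tokens, so they are matched literally, as the slow tokenizer does.
fast_tk.add_tokens([AddedToken(t, lstrip=True, rstrip=True, normalized=False) for t in new_toks])
print(fast_tk("<|im_start|>Who are you?<|im_end|>"))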

ArthurZucker · Mar 25 '24 07:03

This will also be fixed by #28881

ArthurZucker · Mar 25 '24 07:03

This will also be fixed by #28881

Thank you so much!

x54-729 · Mar 25 '24 11:03