
Fix handling of Sequence post-processors in train_new_from_iterator

Open • taidopurason opened this pull request 1 year ago • 0 comments

What does this PR do?

This PR fixes a bug in train_new_from_iterator that affects tokenizers whose post-processor is a Sequence. When such a tokenizer is trained, the special token IDs in the post-processor are not updated to match the new vocabulary. Instead, they are copied unchanged from the original tokenizer.

For example, this affects training a new tokenizer from a Llama-3 tokenizer, as reported in #33998 and #30752.
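The sketch below makes the failure mode concrete. It works on the post-processor JSON, the same format that `Tokenizer.to_str()` prints in the reproduction.

The helper `update_top_level_only` is hypothetical and deliberately simplified. It is not the transformers implementation. It models an ID update that only inspects the top level of the post-processor.

A `Sequence` post-processor keeps its `TemplateProcessing` step inside `processors`. Such an update therefore never reaches the nested `special_tokens` entry, and the original ID 128000 survives.

```python
import copy


def update_top_level_only(post_processor: dict, vocab: dict) -> dict:
    """Hypothetical, simplified ID update that only looks at top-level keys.

    Not the actual transformers code; `vocab` maps token strings to new IDs.
    """
    updated = copy.deepcopy(post_processor)
    # A Sequence has no top-level "special_tokens", so this loop does nothing for it
    for entry in updated.get("special_tokens", {}).values():
        entry["ids"] = [vocab[token] for token in entry["tokens"]]
    return updated


# Trimmed-down Llama 3 post-processor ("single"/"pair" templates omitted)
llama3_post_processor = {
    "type": "Sequence",
    "processors": [
        {"type": "ByteLevel", "add_prefix_space": True, "trim_offsets": False, "use_regex": True},
        {
            "type": "TemplateProcessing",
            "special_tokens": {
                "<|begin_of_text|>": {
                    "id": "<|begin_of_text|>",
                    "ids": [128000],
                    "tokens": ["<|begin_of_text|>"],
                }
            },
        },
    ],
}

new_vocab = {"<|begin_of_text|>": 0}  # BOS ID in the newly trained vocabulary

# A bare TemplateProcessing (no Sequence wrapper) is updated correctly...
flat = llama3_post_processor["processors"][1]
print(update_top_level_only(flat, new_vocab)["special_tokens"]["<|begin_of_text|>"]["ids"])  # [0]

# ...but the same step nested inside a Sequence keeps the stale ID
nested = update_top_level_only(llama3_post_processor, new_vocab)
print(nested["processors"][1]["special_tokens"]["<|begin_of_text|>"]["ids"])  # [128000]
```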

Running the following code:

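```python
from transformers import AutoTokenizer
from datasets import load_dataset
import json
from itertools import islice

# Load the original Llama 3.1 tokenizer (its post-processor is a Sequence)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# Stream a small sample of Estonian Wikipedia to train on
ds = load_dataset("wikimedia/wikipedia", "20231101.et", streaming=True, split="train")

# Train a new tokenizer with a vocabulary size of 1000 on the first 100 articles
new_tokenizer = tokenizer.train_new_from_iterator([x["text"] for x in islice(ds, 100)], 1000)

# Compare the new BOS ID with the ID the post-processor actually prepends
print(f"bos_token_id={new_tokenizer.bos_token_id}")
print(f"'Hello world!' tokenized as {new_tokenizer('Hello world!')['input_ids']}")
print(json.dumps(json.loads(new_tokenizer._tokenizer.to_str())['post_processor'], indent=2))
```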

The output is:

```
bos_token_id=0
'Hello world!' tokenized as [128000, 294, 569, 727, 399, 338, 541, 327, 319, 256]
{
  "type": "Sequence",
  "processors": [
    {
      "type": "ByteLevel",
      "add_prefix_space": true,
      "trim_offsets": false,
      "use_regex": true
    },
    {
      "type": "TemplateProcessing",
      "single": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        }
      ],
      "pair": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        },
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 1
          }
        },
        {
          "Sequence": {
            "id": "B",
            "type_id": 1
          }
        }
      ],
      "special_tokens": {
        "<|begin_of_text|>": {
          "id": "<|begin_of_text|>",
          "ids": [
            128000
          ],
          "tokens": [
            "<|begin_of_text|>"
          ]
        }
      }
    }
  ]
}
```

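As shown, the new tokenizer prepends an incorrect BOS token ID. It adds 128000, the ID of <|begin_of_text|> in the original Llama 3.1 vocabulary, instead of 0, the tokenizer's own bos_token_id. This happens because the TemplateProcessing step nested inside the Sequence still carries the original "ids": [128000].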
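One possible fix is to walk the post-processor JSON recursively. This is sketched here as a hypothetical helper, not the code in this PR.

The helper descends into every `Sequence` entry's `processors` list. At each leaf processor, it rewrites two kinds of IDs from the new vocabulary:

- the `TemplateProcessing` `special_tokens` IDs
- the Bert/Roberta-style `cls`/`sep` pairs

The plain `vocab` dict is an assumption. It stands in for the trained tokenizer's token-to-ID lookup (for example `Tokenizer.token_to_id` in the tokenizers library).

```python
import copy


def remap_special_token_ids(post_processor: dict, vocab: dict) -> dict:
    """Return a copy of a post-processor JSON dict with special-token IDs taken from `vocab`.

    Hypothetical helper: `vocab` maps token strings to IDs in the newly trained tokenizer.
    """
    updated = copy.deepcopy(post_processor)
    _remap_in_place(updated, vocab)
    return updated


def _lookup(token: str, vocab: dict) -> int:
    # Fail loudly instead of silently keeping an ID from the original vocabulary
    if token not in vocab:
        raise ValueError(f"Special token {token!r} is missing from the new vocabulary")
    return vocab[token]


def _remap_in_place(processor: dict, vocab: dict) -> None:
    if processor.get("type") == "Sequence":
        # Recurse into every step of a Sequence post-processor (including nested Sequences)
        for child in processor.get("processors", []):
            _remap_in_place(child, vocab)
        return
    # TemplateProcessing: each special_tokens entry lists its tokens and their IDs
    for entry in processor.get("special_tokens", {}).values():
        entry["ids"] = [_lookup(token, vocab) for token in entry["tokens"]]
    # BertProcessing / RobertaProcessing store [token, id] pairs under "cls" and "sep"
    for key in ("cls", "sep"):
        if key in processor:
            token, _old_id = processor[key]
            processor[key] = [token, _lookup(token, vocab)]


# Trimmed-down Llama 3 post-processor ("single"/"pair" templates omitted)
llama3_post_processor = {
    "type": "Sequence",
    "processors": [
        {"type": "ByteLevel", "add_prefix_space": True, "trim_offsets": False, "use_regex": True},
        {
            "type": "TemplateProcessing",
            "special_tokens": {
                "<|begin_of_text|>": {
                    "id": "<|begin_of_text|>",
                    "ids": [128000],
                    "tokens": ["<|begin_of_text|>"],
                }
            },
        },
    ],
}

fixed = remap_special_token_ids(llama3_post_processor, {"<|begin_of_text|>": 0})
print(fixed["processors"][1]["special_tokens"]["<|begin_of_text|>"]["ids"])  # [0]
```

The helper raises on a missing token rather than keeping the old ID. A vocabulary that lacks a special token then fails at training time. Otherwise the post-processor would silently emit an ID such as 128000 that lies far outside the new vocabulary.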

Fixes #33998 #30752

I welcome feedback and suggestions on this fix.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

  • tokenizers: @ArthurZucker

taidopurason, Oct 18 '24 10:10