minbpe

A possibly faster way to train the tokenizer (pure Python)

Open · ReinforcedKnowledge opened this issue 4 months ago · 3 comments

Hi!

I'm not sure this is the right place to post this; sorry if it isn't.

I think there is a way to make tokenizer training faster.

Where the original code gets slow

I think the time goes mainly into these two lines: https://github.com/karpathy/minbpe/blob/f50ad93aa65072ac23e0a7e0bbc64c9c4e26cc4a/minbpe/basic.py#L33 and https://github.com/karpathy/minbpe/blob/f50ad93aa65072ac23e0a7e0bbc64c9c4e26cc4a/minbpe/basic.py#L39

I think that's because every merge makes full passes over all the data (recounting every pair, then rebuilding the whole token list) and returns new objects each time, which is quite costly.
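To make that cost concrete, here is a simplified naive loop in the spirit of the repository's get_stats/merge calls. It is a sketch over word frequencies, not the repository's actual code, and count_pairs, merge_everywhere and naive_train are hypothetical names. Every merge recounts all pairs from scratch and then rebuilds every word into a brand-new dict, so each step costs two full passes over the data even when the merged pair occurs in only a few words.

```python
from collections import Counter
from typing import Dict, List, Tuple

Pair = Tuple[str, str]
Words = Dict[Tuple[str, ...], int]  # symbols of a word -> word frequency


def count_pairs(words: Words) -> Counter:
    # Full pass 1: recount every adjacent pair, weighted by word frequency.
    counts: Counter = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            counts[pair] += freq
    return counts


def merge_everywhere(words: Words, pair: Pair) -> Words:
    # Full pass 2: rebuild every word (even untouched ones) into a new dict.
    merged = "".join(pair)
    new_words: Words = {}
    for symbols, freq in words.items():
        out: List[str] = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair:
                out.append(merged)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        new_words[key] = new_words.get(key, 0) + freq
    return new_words


def naive_train(words: Words, num_merges: int) -> List[Pair]:
    merges: List[Pair] = []
    for _ in range(num_merges):
        counts = count_pairs(words)
        if not counts:  # every word is a single symbol
            break
        best = max(counts, key=counts.get)
        merges.append(best)
        words = merge_everywhere(words, best)
    return merges


corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5
# Each word becomes its characters plus the end-of-word symbol "Ġ".
words = {tuple(w) + ("Ġ",): f for w, f in Counter(corpus.split()).items()}
print(naive_train(words, 3))  # [('u', 'g'), ('u', 'n'), ('un', 'Ġ')]
```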

Possible solution

I think it's possible to find the best pair with one pass over the pair counts and then update everything in place. I'll try to explain my code. Unfortunately I can't check right now how to integrate it into your train method (as an example or otherwise), but I'll explain my inputs, outputs and code as best I can. I think my train_dict plays the role of the ids list in your code (the data that gets merged), and my pairs_dict plays the role of stats. Also, the repository implements byte-level BPE while I'm going to talk about standard (character-level) BPE, but it's easy to go from one to the other.

I'll use the same corpus as the Hugging Face summary of tokenizers (https://huggingface.co/docs/transformers/en/tokenizer_summary): corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5.

I start with a train_dict holding each word's frequency and its pairs of adjacent symbols, where 'Ġ' marks the end of the word (its trailing space). It has the following format:

```python
{0: [10, [('h', 'u'), ('u', 'g'), ('g', 'Ġ')]],
 1: [5, [('p', 'u'), ('u', 'g'), ('g', 'Ġ')]],
 2: [12, [('p', 'u'), ('u', 'n'), ('n', 'Ġ')]],
 3: [4, [('b', 'u'), ('u', 'n'), ('n', 'Ġ')]],
 4: [5, [('h', 'u'), ('u', 'g'), ('g', 's'), ('s', 'Ġ')]]}
```

The structure is: word id -> [word frequency, list of pairs of adjacent symbols]. The idea is that, once we find the best pair (the next merge rule), we find all the words that contribute to that pair and update them after the merge. The frequency is needed because we keep track of the pairs' frequencies, and a pair's frequency is the sum of the frequencies of the words that contain it.
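A minimal sketch of how train_dict could be built from the corpus, assuming words are split on whitespace and 'Ġ' is appended as the end-of-word symbol. build_train_dict is a hypothetical helper, not code from the issue or the repository. Word ids follow first appearance because Counter keeps insertion order (Python 3.7+).

```python
from collections import Counter
from typing import Dict, List, Tuple

Pair = Tuple[str, str]


def build_train_dict(corpus: str, end: str = "Ġ") -> Dict[int, list]:
    """word id -> [word frequency, list of adjacent symbol pairs]."""
    train_dict: Dict[int, list] = {}
    word_freqs = Counter(corpus.split())  # keeps first-appearance order
    for word_id, (word, freq) in enumerate(word_freqs.items()):
        symbols = list(word) + [end]  # characters plus the end-of-word marker
        pairs: List[Pair] = list(zip(symbols, symbols[1:]))
        train_dict[word_id] = [freq, pairs]
    return train_dict


corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5
train_dict = build_train_dict(corpus)
print(train_dict[4])  # [5, [('h', 'u'), ('u', 'g'), ('g', 's'), ('s', 'Ġ')]]
```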

I also start with the following pairs_dict, which maps each pair to its total frequency and to the ids of the words that contain it. For our corpus it is:

```python
{('h', 'u'): [15, {0, 4}],
 ('u', 'g'): [20, {0, 1, 4}],
 ('g', 'Ġ'): [15, {0, 1}],
 ('p', 'u'): [17, {1, 2}],
 ('u', 'n'): [16, {2, 3}],
 ('n', 'Ġ'): [16, {2, 3}],
 ('b', 'u'): [4, {3}],
 ('g', 's'): [5, {4}],
 ('s', 'Ġ'): [5, {4}]}
```
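The pairs_dict above can be derived from train_dict in one pass. Here is a sketch with a hypothetical build_pairs_dict helper: each pair accumulates the frequency of every word it appears in, once per occurrence, and records that word's id.

```python
from typing import Dict, Tuple

Pair = Tuple[str, str]


def build_pairs_dict(train_dict: Dict[int, list]) -> Dict[Pair, list]:
    """pair -> [total frequency, set of ids of the words containing the pair]."""
    pairs_dict: Dict[Pair, list] = {}
    for word_id, (freq, pairs) in train_dict.items():
        for pair in pairs:
            entry = pairs_dict.setdefault(pair, [0, set()])
            entry[0] += freq  # a pair occurring twice in one word counts twice
            entry[1].add(word_id)
    return pairs_dict


train_dict = {
    0: [10, [("h", "u"), ("u", "g"), ("g", "Ġ")]],
    1: [5, [("p", "u"), ("u", "g"), ("g", "Ġ")]],
    2: [12, [("p", "u"), ("u", "n"), ("n", "Ġ")]],
    3: [4, [("b", "u"), ("u", "n"), ("n", "Ġ")]],
    4: [5, [("h", "u"), ("u", "g"), ("g", "s"), ("s", "Ġ")]],
}
pairs_dict = build_pairs_dict(train_dict)
print(pairs_dict[("u", "g")])  # [20, {0, 1, 4}]
```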

My training loop is simple: find the best pair (the merge rule), record it in the vocabulary, then update both train_dict and pairs_dict:

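```python
from typing import Dict

def train_loop(train_dict: Dict, pairs_dict: Dict, vocab: Dict[str, int], num_merges: int) -> None:
    base_vocab_size = len(vocab)  # vocab holds the base symbols; merges get the next ids
    for i in range(num_merges):
        if not pairs_dict:  # every word is a single symbol: nothing left to merge
            break
        best = max(pairs_dict, key=lambda pair: pairs_dict[pair][0])
        vocab["".join(best)] = base_vocab_size + i
        merge_pairs(train_dict, pairs_dict, best)
```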
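A note on the max call: max returns the first maximal element it meets, so ties between pairs with equal counts are broken by dict insertion order, roughly the order in which pairs were first counted. A small illustration, where best_pair is a hypothetical wrapper using the same selection rule and the entries keep only the counts (word-id sets omitted):

```python
from typing import Dict, List, Tuple

Pair = Tuple[str, str]


def best_pair(pairs_dict: Dict[Pair, List]) -> Pair:
    # Highest total frequency wins; on ties, the earliest-inserted pair wins.
    return max(pairs_dict, key=lambda pair: pairs_dict[pair][0])


# Counts from the example pairs_dict before any merge: ('u', 'g') wins with 20.
start = {
    ("h", "u"): [15], ("u", "g"): [20], ("g", "Ġ"): [15], ("p", "u"): [17],
    ("u", "n"): [16], ("n", "Ġ"): [16], ("b", "u"): [4], ("g", "s"): [5],
    ("s", "Ġ"): [5],
}
print(best_pair(start))  # ('u', 'g')

# A tie: both pairs have 16, and the one inserted first is returned.
tie = {("u", "n"): [16], ("n", "Ġ"): [16]}
print(best_pair(tie))  # ('u', 'n')
```

If ties should not depend on insertion order, one option is a key such as (pairs_dict[pair][0], pair), which prefers the lexicographically largest pair among equal counts.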

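I'm not very proud of the merge_pairs function, but it does the job 😅. For each word, it rewrites the list of pairs around every occurrence of the best pair (a, b): the preceding pair (x, a) becomes (x, ab), the following pair (b, y) becomes (ab, y), and the word's frequency is moved from the old pairs to the new ones in pairs_dict.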

```python
def merge_pairs(words_dict, pairs_dict, max_freq_pair):
    """Apply the merge max_freq_pair to every word, updating both dicts in place."""
    max_freq_pair_merged = "".join(max_freq_pair)
    for word_id in words_dict:
        word_freq = words_dict[word_id][0]
        pairs = words_dict[word_id][1]
        new_pairs = []
        i = 0
        while i < len(pairs):
            if pairs[i] == max_freq_pair:
                # Preceding pair (x, a) becomes (x, ab)
                if new_pairs:
                    prev_pair = new_pairs[-1]
                    new_pairs[-1] = (prev_pair[0], max_freq_pair_merged)
                    update_pairs_dict(
                        pairs_dict, prev_pair, -word_freq, word_id
                    )
                    update_pairs_dict(
                        pairs_dict, new_pairs[-1], word_freq, word_id
                    )
                # Following pair (b, y) becomes (ab, y)
                if i < len(pairs) - 1:
                    next_pair = (max_freq_pair_merged, pairs[i + 1][1])
                    new_pairs.append(next_pair)
                    update_pairs_dict(
                        pairs_dict, pairs[i + 1], -word_freq, word_id
                    )
                    update_pairs_dict(
                        pairs_dict, next_pair, word_freq, word_id
                    )
                    i += 1  # Skip the next pair as it's now merged
            else:
                new_pairs.append(pairs[i])
            i += 1
        words_dict[word_id][1] = new_pairs

    # Every occurrence of max_freq_pair is gone, so delete it from pairs_dict
    del pairs_dict[max_freq_pair]
```
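The pair-list rewrite at the core of merge_pairs can be looked at in isolation. merge_word below is a hypothetical helper that applies the same rules to a single word's pair list without touching pairs_dict: the left neighbour (x, a) becomes (x, ab), the right neighbour (b, y) becomes (ab, y) and is consumed, and the merged pair itself disappears.

```python
from typing import List, Tuple

Pair = Tuple[str, str]


def merge_word(pairs: List[Pair], best: Pair) -> List[Pair]:
    merged = "".join(best)
    new_pairs: List[Pair] = []
    i = 0
    while i < len(pairs):
        if pairs[i] == best:
            if new_pairs:  # left neighbour (x, a) -> (x, ab)
                new_pairs[-1] = (new_pairs[-1][0], merged)
            if i < len(pairs) - 1:  # right neighbour (b, y) -> (ab, y)
                new_pairs.append((merged, pairs[i + 1][1]))
                i += 1  # the right neighbour has been consumed
        else:
            new_pairs.append(pairs[i])
        i += 1
    return new_pairs


print(merge_word([("h", "u"), ("u", "g"), ("g", "s"), ("s", "Ġ")], ("u", "g")))
# [('h', 'ug'), ('ug', 's'), ('s', 'Ġ')]
print(merge_word([("a", "a"), ("a", "a"), ("a", "a")], ("a", "a")))
# [('aa', 'aa')], i.e. the symbols a a a a become aa aa
```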

As you can see, I make only one pass through the training dictionary and touch only the affected entries of the pairs dictionary, instead of two full passes (one through the training data to merge, one to recount all the pairs). All updates are done in place.
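One further step the pairs_dict structure allows, sketched here as a hypothetical variant rather than the code above: since every entry already stores the ids of the words containing the pair, a merge only needs to visit those words instead of the whole train_dict. Ids are only ever added to these sets, so a set may hold stale ids; scanning such a word just finds no match. build, bump, merge_indexed, train and recount are hypothetical helper names. Ids are visited in sorted order so that new pairs enter pairs_dict in the same order a full scan would produce, which keeps the max tie-breaking unchanged; recount rebuilds the counts from scratch as a consistency check.

```python
from collections import Counter
from typing import Dict, List, Tuple

Pair = Tuple[str, str]


def build(corpus: str, end: str = "Ġ"):
    """Build train_dict and pairs_dict in the format described above."""
    train_dict: Dict[int, list] = {}
    pairs_dict: Dict[Pair, list] = {}
    for word_id, (word, freq) in enumerate(Counter(corpus.split()).items()):
        symbols = list(word) + [end]
        pairs = list(zip(symbols, symbols[1:]))
        train_dict[word_id] = [freq, pairs]
        for pair in pairs:
            entry = pairs_dict.setdefault(pair, [0, set()])
            entry[0] += freq
            entry[1].add(word_id)
    return train_dict, pairs_dict


def bump(pairs_dict: Dict[Pair, list], pair: Pair, delta: int, word_id: int) -> None:
    """Add delta to a pair's count; drop the pair once its count reaches zero."""
    entry = pairs_dict.setdefault(pair, [0, set()])
    entry[0] += delta
    if delta > 0:
        entry[1].add(word_id)
    if entry[0] <= 0:
        del pairs_dict[pair]


def merge_indexed(train_dict: Dict[int, list], pairs_dict: Dict[Pair, list], best: Pair) -> None:
    merged = "".join(best)
    for word_id in sorted(pairs_dict[best][1]):  # only words recorded with best
        freq, pairs = train_dict[word_id]
        new_pairs: List[Pair] = []
        i = 0
        while i < len(pairs):
            if pairs[i] == best:
                if new_pairs:  # (x, a) -> (x, ab)
                    left = new_pairs[-1]
                    new_pairs[-1] = (left[0], merged)
                    bump(pairs_dict, left, -freq, word_id)
                    bump(pairs_dict, new_pairs[-1], freq, word_id)
                if i < len(pairs) - 1:  # (b, y) -> (ab, y)
                    right = pairs[i + 1]
                    new_pairs.append((merged, right[1]))
                    bump(pairs_dict, right, -freq, word_id)
                    bump(pairs_dict, new_pairs[-1], freq, word_id)
                    i += 1
            else:
                new_pairs.append(pairs[i])
            i += 1
        train_dict[word_id][1] = new_pairs
    pairs_dict.pop(best, None)  # every occurrence of best is gone


def train(corpus: str, num_merges: int):
    train_dict, pairs_dict = build(corpus)
    merges: List[Pair] = []
    for _ in range(num_merges):
        if not pairs_dict:
            break
        best = max(pairs_dict, key=lambda pair: pairs_dict[pair][0])
        merges.append(best)
        merge_indexed(train_dict, pairs_dict, best)
    return merges, train_dict, pairs_dict


def recount(train_dict: Dict[int, list]) -> Dict[Pair, int]:
    """From-scratch pair counts, to check the incremental bookkeeping."""
    counts: Counter = Counter()
    for freq, pairs in train_dict.values():
        for pair in pairs:
            counts[pair] += freq
    return dict(counts)


corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5
merges, train_dict, pairs_dict = train(corpus, 3)
print(merges)  # [('u', 'g'), ('u', 'n'), ('un', 'Ġ')]
```

Comparing recount(train_dict) with the counts stored in pairs_dict after a few merges is a cheap way to catch bookkeeping mistakes in any incremental variant.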

The pairs_dict updates are weighted by the frequency of the word in which a pair appears or disappears.

```python
def update_pairs_dict(pairs_dict, pair, freq_change, word_id):
    # Word ids are only ever added, never removed, so a set may keep stale ids.
    if pair in pairs_dict:
        pairs_dict[pair][0] += freq_change
        if freq_change > 0:  # If we are adding frequency, add the word ID
            pairs_dict[pair][1].add(word_id)
        if pairs_dict[pair][0] <= 0:  # If frequency is zero or less, delete the pair
            del pairs_dict[pair]
    elif freq_change > 0:  # A new pair created by the merge
        pairs_dict[pair] = [freq_change, {word_id}]
    # Subtracting from an untracked pair is ignored instead of storing a negative count.
```
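To make the bookkeeping concrete, here is the net effect of the first merge, ('u', 'g'), worked out by hand with a Counter (a standalone illustration, not part of the training code): each affected word moves its own frequency from its old neighbour pairs to the new ones, and pairs whose count reaches zero disappear.

```python
from collections import Counter

# Pair frequencies before the first merge, taken from the pairs_dict above.
before = Counter({
    ("h", "u"): 15, ("u", "g"): 20, ("g", "Ġ"): 15, ("p", "u"): 17,
    ("u", "n"): 16, ("n", "Ġ"): 16, ("b", "u"): 4, ("g", "s"): 5, ("s", "Ġ"): 5,
})

# (old pair, new pair, word frequency) for each neighbour rewritten by the merge.
moves = [
    (("h", "u"), ("h", "ug"), 10), (("g", "Ġ"), ("ug", "Ġ"), 10),  # word 0, "hug"
    (("p", "u"), ("p", "ug"), 5), (("g", "Ġ"), ("ug", "Ġ"), 5),    # word 1, "pug"
    (("h", "u"), ("h", "ug"), 5), (("g", "s"), ("ug", "s"), 5),    # word 4, "hugs"
]

after = before.copy()
for old, new, freq in moves:
    after[old] -= freq
    after[new] += freq
del after[("u", "g")]  # the merged pair itself no longer exists
after = +after         # unary plus drops pairs whose count fell to zero
print(after[("h", "ug")], after[("p", "u")], ("h", "u") in after)  # 15 12 False
```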

Context

I know this repository is for educational purposes, and its code is deliberately concise and clear.

I also know that speed work on tokenizers usually focuses on tokenization (encoding) rather than training, especially since there is Rust code that is far faster at it and can run in a distributed fashion as well.

But I thought someone might benefit from this. I'm implementing a bunch of things myself to learn, and I've learned a lot by trying to make my code more efficient. I'm not saying my code is perfect; there is surely a way to use better data structures here, such as linked lists or graphs. I just couldn't do it yet. I'm also aware of Python's limits compared to Rust (and some other languages, but I think Rust is the most used for tokenizers) for distributed compute, or for any memory-efficient application for that matter. I just thought it'd be great to push the limits of what we can do with tokenizers in pure Python.

ReinforcedKnowledge · Feb 20 '24 21:02