
Train BasicTokenizer on GPU with PyTorch, 100x speedup

Open kuprel opened this issue 2 years ago • 5 comments

The following files are added:

  • minbpe/torch/base.py
    • Contains merge_torch
  • minbpe/torch/basic.py
    • Contains BasicTokenizerTorch, overrides the train and encode methods of BasicTokenizer
  • minbpe/torch/regex.py
    • Contains RegexTokenizerTorch, overrides the encode_ordinary method of RegexTokenizer
  • minbpe/torch/gpt4.py
    • Contains GPT4TokenizerTorch, mostly inherits from GPT4Tokenizer, but uses RegexTokenizerTorch's encode method
  • train_torch.py
    • Similar to train.py but trains BasicTokenizerTorch
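For context, the vectorized merge can be sketched roughly as below. This is a minimal illustration, not the PR's actual `merge_torch`; the `cummax` trick is one way to resolve overlapping matches (e.g. merging the pair `(a, a)` in a run `a a a`) so the result matches the sequential greedy merge.

```python
import torch

def merge_torch(ids: torch.Tensor, pair: tuple[int, int], idx: int) -> torch.Tensor:
    """Replace every non-overlapping occurrence of `pair` in `ids` with `idx`."""
    a, b = pair
    n = ids.numel()
    match = torch.zeros(n, dtype=torch.bool)
    match[:-1] = (ids[:-1] == a) & (ids[1:] == b)

    # Resolve overlaps greedily left-to-right (only matters when a == b):
    # within each run of consecutive matches, keep every other position.
    pos = torch.arange(n)
    prev = torch.cat([torch.tensor([False]), match[:-1]])
    starts = match & ~prev
    last_start = torch.cummax(torch.where(starts, pos, torch.full_like(pos, -1)), dim=0).values
    match = match & ((pos - last_start) % 2 == 0)

    # Write the new token at merge positions and drop the token that follows.
    out = ids.clone()
    out[match] = idx
    drop = torch.zeros(n, dtype=torch.bool)
    drop[1:] = match[:-1]
    return out[~drop]
```

Unlike the pure-Python loop in `minbpe`'s `merge`, every step here is a tensor op, so it runs entirely on the GPU when `ids` lives there.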

The following files are modified:

  • minbpe/__init__.py
    • Import torch tokenizers
  • tests/test_tokenizer.py
    • Add torch tokenizers to tests

It takes 67.4 seconds on an H100 80GB SXM5 to train BasicTokenizerTorch with a vocab_size of 512 on 308MB of Enron emails. The original code takes 2 hr 15 min on an M2 Air with Python 3.11 to do the same.
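Most of the per-merge training cost is counting pair frequencies over the whole corpus; this is where vectorization pays off. A hedged sketch of how the count can be done in one shot (again, hypothetical names, not the PR's actual code):

```python
import torch

def most_common_pair(ids: torch.Tensor) -> tuple[int, int]:
    """Return the most frequent adjacent pair of token ids as a Python tuple."""
    # Stack consecutive ids into an (n-1, 2) matrix of pairs,
    # then count unique rows in a single GPU-friendly pass.
    pairs = torch.stack([ids[:-1], ids[1:]], dim=1)
    uniq, counts = torch.unique(pairs, dim=0, return_counts=True)
    a, b = uniq[counts.argmax()]
    return int(a), int(b)
```

Each training step then reduces to one `most_common_pair`-style count followed by one vectorized merge, instead of a Python dict update over hundreds of millions of positions.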

I'm not sure RegexTokenizerTorch or GPT4TokenizerTorch can benefit much from PyTorch, since the regex split produces many chunks of varying lengths, i.e. a "ragged tensor". These tokenizers are still helpful for sanity checks, though: the test_gpt4_tiktoken_equality tests all pass, suggesting that merge_torch is implemented correctly.

I also made a separate repository, minbpe-pytorch, in case adding PyTorch support is beyond the scope of this project.

kuprel avatar Feb 22 '24 01:02 kuprel

Using an H100 and int16, it's now a 108x speedup over the original implementation on the M2 Air
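The int16 part works because a vocab_size of 512 fits comfortably in a signed 16-bit integer (max 32767), so token ids can be stored at half the width of int32, roughly halving memory traffic on the GPU. A quick illustration (assumption: ids are kept in a flat tensor, as in the sketches above):

```python
import torch

# Token ids for vocab_size=512 fit in int16, so the working tensor
# can be downcast from int32 to int16 to halve bytes per token.
ids32 = torch.randint(0, 512, (1_000_000,), dtype=torch.int32)
ids16 = ids32.to(torch.int16)

bytes32 = ids32.element_size() * ids32.numel()  # 4 MB
bytes16 = ids16.element_size() * ids16.numel()  # 2 MB
```

Note this caps the trainable vocab at 32767 with signed int16; a larger vocab would need int32 again.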

kuprel avatar Feb 23 '24 20:02 kuprel

All of the tests pass

[screenshot: passing test output, Feb 24 2024]

kuprel avatar Feb 25 '24 04:02 kuprel

Ok I'll step through this soon to take a look. Not sure that I love duplicating everything and creating torch versions of it. Would we be able to potentially isolate the def that is the bottleneck (I'm guessing in base.py), and just surgically have a fast version of one of those defs? If that isn't straight forward happy to link to minbpe-pytorch.

karpathy avatar Feb 27 '24 00:02 karpathy

Thanks for the feedback! I made the diff more surgical. Now the only added files are:

  • minbpe/basic_torch.py
    • Contains merge_torch and BasicTorchTokenizer, overrides the train and encode methods of BasicTokenizer
  • train_torch.py
    • Similar to train.py but trains BasicTorchTokenizer

And the following files are lightly modified:

  • minbpe/__init__.py
    • Import BasicTorchTokenizer
  • tests/test_tokenizer.py
    • Add BasicTorchTokenizer to tests

kuprel avatar Feb 27 '24 06:02 kuprel