Train BasicTokenizer on GPU with PyTorch, 100x speedup
The following files are added:
- `minbpe/torch/base.py`
  - Contains `merge_torch`
- `minbpe/torch/basic.py`
  - Contains `BasicTokenizerTorch`, which overrides the `train` and `encode` methods of `BasicTokenizer`
- `minbpe/torch/regex.py`
  - Contains `RegexTokenizerTorch`, which overrides the `encode_ordinary` method of `RegexTokenizer`
- `minbpe/torch/gpt4.py`
  - Contains `GPT4TokenizerTorch`, which mostly inherits from `GPT4Tokenizer` but uses `RegexTokenizerTorch`'s `encode` method
- `train_torch.py`
  - Similar to `train.py` but trains `BasicTokenizerTorch`
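The hot loop in training is the repeated pair merge. The PR's actual `merge_torch` lives in `minbpe/torch/base.py`; as a rough illustration of the idea (this sketch is mine, not the PR's code), a greedy left-to-right merge can be expressed entirely in vectorized tensor ops:

```python
import torch

def merge_torch(ids: torch.Tensor, pair: tuple[int, int], new_id: int) -> torch.Tensor:
    """Replace every non-overlapping occurrence of `pair` in `ids` with `new_id`,
    greedily left to right, using only vectorized tensor ops (no Python loop)."""
    if ids.numel() < 2:
        return ids
    # m[i] is True where (ids[i], ids[i+1]) == pair
    m = (ids[:-1] == pair[0]) & (ids[1:] == pair[1])
    # Within each run of consecutive matches (e.g. "aaa" for pair (a, a)),
    # keep every other match starting from the run's first element; this
    # reproduces the greedy behavior of the reference Python loop.
    idx = torch.arange(m.numel())
    run_start = m & ~torch.cat([torch.tensor([False]), m[:-1]])
    last_start = torch.cummax(torch.where(run_start, idx, torch.full_like(idx, -1)), dim=0).values
    keep = m & ((idx - last_start) % 2 == 0)
    # Overwrite the first element of each merged pair, drop the second.
    out = ids.clone()
    pos = keep.nonzero(as_tuple=True)[0]
    out[pos] = new_id
    drop = torch.zeros(ids.numel(), dtype=torch.bool)
    drop[pos + 1] = True
    return out[~drop]
```

The run-parity trick (`cummax` over run starts) handles overlapping matches like `aaaa` with pair `(a, a)`, which a plain mask would merge incorrectly.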
The following files are modified:
- `minbpe/__init__.py`
  - Import torch tokenizers
- `tests/test_tokenizer.py`
  - Add torch tokenizers to tests
It takes 67.4 seconds on an H100 80GB SXM5 to train `BasicTokenizerTorch` with a `vocab_size` of 512 on 308MB of Enron emails. The original code takes 2hrs 15min on an M2 Air with Python 3.11 to do the same.
I'm not sure whether `RegexTokenizerTorch` or `GPT4TokenizerTorch` can benefit much from PyTorch, since the input is split into many chunks of varying lengths, i.e. a "ragged tensor". These tokenizers are helpful for sanity checks, though: for example, the `test_gpt4_tiktoken_equality` tests all pass, suggesting that `merge_torch` is implemented correctly.
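The raggedness is easy to see: regex splitting produces chunks of many different lengths, which don't stack into a rectangular tensor without padding. A toy illustration with a simplified pattern (the real tokenizer uses the GPT-4 split regex via the `regex` module, but the raggedness is the same):

```python
import re

text = "Hello world, this is a tokenizer test!"
# Simplified split: an optional space plus word chars, or a run of punctuation.
chunks = re.findall(r" ?\w+|[^\w\s]+", text)
lengths = [len(c) for c in chunks]
print(chunks)   # ['Hello', ' world', ',', ' this', ' is', ' a', ' tokenizer', ' test', '!']
print(lengths)  # [5, 6, 1, 5, 3, 2, 10, 5, 1] -- a "ragged" batch
```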
I also made a new repository, minbpe-pytorch, in case adding PyTorch support is beyond the scope of this project.
Using an H100 and int16, it's now a 108x speedup over the original implementation on the M2 Air.
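One reason int16 can help (beyond halved memory traffic) is that two int16 token ids pack into a single int32 key, so the most frequent adjacent pair can be counted with one `torch.unique` call on device. A hypothetical sketch of that counting step (my names, not the PR's code):

```python
import torch

def most_common_pair(ids: torch.Tensor) -> tuple[tuple[int, int], int]:
    """Count adjacent pairs by packing each (left, right) pair of int16 ids
    into one int32 key, then letting torch.unique do the counting."""
    a = ids[:-1].to(torch.int32)
    b = ids[1:].to(torch.int32)
    keys = (a << 16) | b  # safe while ids are non-negative and < 2**15
    vals, counts = torch.unique(keys, return_counts=True)
    best = vals[counts.argmax()]
    return (int(best >> 16), int(best & 0xFFFF)), int(counts.max())
```

With int16 ids this keeps the whole counting step on the GPU as a single fused-friendly pass, instead of a Python dict of pair counts.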
All of the tests pass.
Ok I'll step through this soon to take a look. Not sure that I love duplicating everything and creating torch versions of it. Would we be able to potentially isolate the def that is the bottleneck (I'm guessing in base.py), and just surgically have a fast version of one of those defs? If that isn't straightforward, happy to link to minbpe-pytorch.
Thanks for the feedback! I made the diff more surgical. Now the only added files are:
- `minbpe/basic_torch.py`
  - Contains `merge_torch` and `BasicTorchTokenizer`, which overrides the `train` and `encode` methods of `BasicTokenizer`
- `train_torch.py`
  - Similar to `train.py` but trains `BasicTorchTokenizer`
And the following files are lightly modified:
- `minbpe/__init__.py`
  - Import `BasicTorchTokenizer`
- `tests/test_tokenizer.py`
  - Add `BasicTorchTokenizer` to tests