Question about Encoder Logic
I noticed the encode() method has extra logic with a while loop to find the lowest merge index:
def encode(self, text):
    text_bytes = text.encode("utf-8")  # raw bytes
    ids = list(text_bytes)  # list of integers in range 0..255
    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
        if pair not in self.merges:
            break  # nothing else can be merged anymore
        idx = self.merges[pair]
        ids = merge(ids, pair, idx)
    return ids
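The snippets in this thread call two helpers, get_stats and merge, that aren't shown. Below is a minimal sketch of what they are assumed to do, modeled on the minbpe helpers of the same names but not copied from the repo:

```python
def get_stats(ids):
    """Count how often each consecutive pair of token ids occurs."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    """Return a new list in which every occurrence of `pair`, scanned left
    to right without overlap, is replaced by the single token `idx`."""
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


print(get_stats([1, 2, 3, 1, 2]))         # {(1, 2): 2, (2, 3): 1, (3, 1): 1}
print(merge([1, 2, 3, 1, 2], (1, 2), 4))  # [4, 3, 4]
```

Because merge() replaces all occurrences in one pass, each merge rule needs to be applied at most once per encode call.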
Can we simplify it like this:
def encode(self, text):
    tokens = text.encode("utf-8")
    tokens = list(map(int, tokens))
    for pair, index in self.merges.items():
        tokens = merge(tokens, pair, index)
    return tokens
Since merge() replaces all occurrences of a pair, it seems a simple for loop would suffice. Is there a reason for the more complex logic? I trained my tokenizer and the BasicTokenizer on the same text data and got exactly the same vocab and encoder output. Maybe I missed something; could you clarify?
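One way to sanity-check the claim at toy scale: train a tiny BPE on a made-up corpus, then compare the while-loop encoder against the plain for loop on random snippets. The corpus, the train() helper, and the merge count are illustrative assumptions, and the check relies on the merges dict being filled in training order:

```python
import random


def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def train(text, num_merges):
    """Hypothetical toy BPE training: repeatedly merge the most frequent pair."""
    ids = list(text.encode("utf-8"))
    merges = {}
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break
        pair = max(stats, key=stats.get)
        ids = merge(ids, pair, 256 + i)
        merges[pair] = 256 + i
    return merges


def encode_greedy(text, merges):
    """The original approach: always merge the lowest-index pair present."""
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids


def encode_loop(text, merges):
    """The proposed approach: apply every merge once, in dict order."""
    ids = list(text.encode("utf-8"))
    for pair, idx in merges.items():
        ids = merge(ids, pair, idx)
    return ids


corpus = "the cat sat on the mat; the cat ate the rat. " * 20
merges = train(corpus, 20)
rng = random.Random(0)
for _ in range(1000):
    s = rng.randint(0, len(corpus))
    snippet = corpus[s:s + rng.randint(0, 30)]
    assert encode_greedy(snippet, merges) == encode_loop(snippet, merges)
print("all snippets encode identically")
```

The two agree because merging a pair can only create new pairs that contain the freshly created id, and any merge involving that id must have been learned later. So the while loop ends up visiting merges in increasing index order anyway; the for loop just also visits the ones that don't occur.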
Thanks!
Update: I added a pytest to my forked repo to show that my version is also correct, for anyone interested in trying it out.
I think it's slightly different, but the effect should be the same (as long as you're on Python 3.7 or later). Your implementation always iterates over every merge, while the original code can exit early. I think the original was written this way to guard against the dictionary not preserving the order in which items were added. Karpathy seems to mention this in the video, but dicts have been guaranteed to preserve insertion order since Python 3.7.
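To see why iteration order matters, here is a toy case with hypothetical merges deliberately inserted out of learning order (as could happen if a merges table were rebuilt from a source that doesn't keep rank order). The plain loop over merges.items() then gives a different result, while sorting by merge index, or the original min()-based loop, doesn't depend on the dict's order:

```python
def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


# Hypothetical merges: (97, 97) -> 256 was learned first and
# (256, 97) -> 257 second, but they were inserted in reverse order.
merges = {(256, 97): 257, (97, 97): 256}


def encode_loop(ids, merges):
    # Relies on dict iteration order matching learning order.
    for pair, idx in merges.items():
        ids = merge(ids, pair, idx)
    return ids


def encode_sorted(ids, merges):
    # Orders merges by their index, so dict order is irrelevant.
    for pair in sorted(merges, key=merges.get):
        ids = merge(ids, pair, merges[pair])
    return ids


def encode_greedy(ids, merges):
    # Picks the lowest-index pair present each round; also order-independent.
    while len(ids) >= 2:
        pair = min(zip(ids, ids[1:]), key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids


ids = list("aaa".encode("utf-8"))  # [97, 97, 97]
print(encode_loop(ids, merges))    # [256, 97]
print(encode_sorted(ids, merges))  # [257]
print(encode_greedy(ids, merges))  # [257]
```

A merges dict built during training in Python 3.7+ is already in learning order, so in that case the plain loop is correct; it only breaks when that assumption doesn't hold.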
As @202030481266 mentioned, your simpler version iterates over every merge in the vocabulary. A realistic tokenizer has a lot of merges (~50k for GPT-2, ~200k for GPT-4o), so at practical scale your approach does far more work, and most of the merges it applies won't even occur in the chunk of text being processed.
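A rough way to see the cost difference is to count merge() passes instead of timing. In this sketch (toy corpus, hypothetical train() helper, 20 merges), the plain loop always makes one pass per learned merge, while the original loop makes one pass per merge that actually applies, which for a 3-byte chunk is at most 2:

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def train(text, num_merges):
    """Hypothetical toy BPE training, only used to get a merges dict."""
    ids = list(text.encode("utf-8"))
    merges = {}
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break
        pair = max(stats, key=stats.get)
        ids = merge(ids, pair, 256 + i)
        merges[pair] = 256 + i
    return merges


def encode_loop_counted(text, merges):
    """Plain loop: one merge() pass per learned merge, present or not."""
    ids, passes = list(text.encode("utf-8")), 0
    for pair, idx in merges.items():
        ids = merge(ids, pair, idx)
        passes += 1
    return ids, passes


def encode_greedy_counted(text, merges):
    """Original approach: one merge() pass per merge that actually applies."""
    ids, passes = list(text.encode("utf-8")), 0
    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
        passes += 1
    return ids, passes


merges = train("the cat sat on the mat; the cat ate the rat. " * 20, 20)
ids_a, loop_passes = encode_loop_counted("cat", merges)
ids_b, greedy_passes = encode_greedy_counted("cat", merges)
print(ids_a == ids_b, loop_passes, greedy_passes)
```

With tens of thousands of merges, the plain loop's pass count grows with the vocabulary, while the original's grows only with how many merges the chunk actually needs.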
There is one slightly overcomplicated thing in Karpathy's code, though. He calls get_stats inside encode but only uses the keys of the dictionary it returns. Since only the keys are used, there's no point computing the counts (which is the whole purpose of get_stats). Instead of get_stats(ids), it's less work to pair up consecutive ids directly with zip(ids, ids[1:]). Even if ids has only one element, this still works without an index-out-of-range error; it simply yields no pairs.
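A sketch of that suggestion, using hypothetical toy merges. One caveat: zip() on a one-element list is fine, but min() over an empty iterable raises ValueError, so the len(ids) >= 2 guard is still needed (alternatively, pass default= to min()):

```python
def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        # Only the adjacent pairs are needed, not their counts,
        # so zip(ids, ids[1:]) stands in for get_stats(ids).
        pair = min(zip(ids, ids[1:]), key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no remaining pair has a learned merge
        ids = merge(ids, pair, merges[pair])
    return ids


merges = {(97, 97): 256, (256, 97): 257}  # hypothetical toy merges
print(encode("aaa", merges))  # [257]
print(encode("a", merges))    # [97]
```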
I prefer to do it like this:
def encode2(text):
    tokens = list(text.encode('utf-8'))
    # sort the merged pairs by merge index, i.e. the order they were learned
    pairs = sorted(merges, key=merges.get)
    for pair in pairs:
        tokens = merge(tokens, pair, merges[pair])
    return tokens
Hope it will be helpful to you.
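The code is logically equivalent but differs in efficiency. Karpathy's code picks pairs by their merge index (the dict values), so it doesn't depend on the dict's iteration order; a plain loop over merges.items() does depend on the dict preserving insertion order.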
I came up with the same encode2 function. Yes, you are right that encode2 is simpler and logically equivalent, but it is slower (by ~78% on my randomly generated test cases):
import random
import time

# assumes `text` (the corpus string) and the encode/encode2 functions
# below are already defined
encode_time = 0
encode2_time = 0
for _ in range(500000):
    # pick a random snippet of up to 30 characters
    l = random.randint(0, 30)
    s = random.randint(0, len(text))
    e = min(len(text)-1, s+l)
    test = text[s:e]
    t1 = time.time()
    e1 = encode(test)
    t2 = time.time()
    encode_time += (t2-t1)
    t1 = time.time()
    e2 = encode2(test)
    t2 = time.time()
    encode2_time += (t2-t1)
    assert e1 == e2  # both encoders must agree
print(f'encode time: {encode_time:5f}s')
print(f'encode2 time: {encode2_time:5f}s')
print(f'encode2/encode ratio: {encode2_time/encode_time:2f}')
encode time: 10.358455s
encode2 time: 18.415130s
encode2/encode ratio: 1.777787
def encode(text):
    """ encode text into tokens"""
    # first get the int repr of unicode bytes
    ids = list(text.encode('utf-8'))
    while len(ids) >= 2:
        stats = get_stats(ids)
        # merges maps each pair to its new token id; pairs merged earlier in
        # training have smaller ids, so pick the pair with the smallest id
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids
def encode2(text):
    """ naive encode that is slower"""
    ids = list(text.encode('utf-8'))
    for pair, merged_id in merges.items():
        ids = merge(ids, pair, merged_id)
    return ids