Help with ModernBert tokenizer (BPE+special tokens)
@chengchingwen, would you mind sharing some pointers on implementing ModernBert tokenizer?
I've tried it both from scratch and with your package and I can't get either to work -- example below.
Source: https://huggingface.co/answerdotai/ModernBERT-base/blob/main/tokenizer_config.json
My attempt: https://github.com/svilupp/ModernBert.jl
I tried with your package to load the merges and mimic how you use it in some of the tests/examples: https://github.com/svilupp/ModernBert.jl/blob/01819a6a762eb0d6ff8ca0c63e3d0418b9a48ce9/src/bytepair.jl#L164
Failing examples: https://github.com/svilupp/ModernBert.jl/blob/01819a6a762eb0d6ff8ca0c63e3d0418b9a48ce9/examples/verify.jl#L34
text = "The capital of France is [MASK]."
tokens = tokenize(tokenizer, text)
@test tokens ==
["[CLS]", "The", "Ġcapital", "Ġof", "ĠFrance", "Ġis", " [MASK]", ".", "[SEP]"
I struggle with catching the special token [MASK].
Test Failed at /Users/simljx/Documents/GitHub/ModernBert.jl/examples/verify.jl:34
Expression: tokens2 == ["[CLS]", "The", "Ġcapital", "Ġof", "ĠFrance", "Ġis", " [MASK]", ".", "[SEP]"]
Evaluated: ["[CLS]", "The", "Ġcapital", "Ġof", "ĠFrance", "Ġis", "Ġ", **"[MASK]",** ".", "[SEP]"] == ["[CLS]", "The", "Ġcapital", "Ġof", "ĠFrance", "Ġis", **" [MASK]",** ".", "[SEP]"]
I tried using MatchTokenization and adding the special tokens. I also tried introducing my own tokenization layer (MaskTokenization) to manually fix it, but it sits too high up in the stack -- it makes no difference:
# Create tokenizer pipeline
base_tokenizer = BPE(bpe_merges)
tokenizer = BPETokenizer(
    TextEncodeBase.MatchTokenization(
        MaskTokenization(
            CodeNormalizer(
                BPETokenization(
                    GPT2Tokenization(),
                    base_tokenizer
                ),
                gpt2_codemap()
            ),
            "[MASK]"
        ),
        collect(keys(special_tokens))
    )
)
Without match tokenization, it splits up the special tokens, e.g.,
"Ġ[", "MASK", "]."
Tokenizer setting:
"normalizer": { "type": "NFC" }, "pre_tokenizer": { "type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": true },
Would you have any pointers on where to start? I'm not sure what else to try.
Did you check and compare the result of Transformers.HuggingFace.load_tokenizer("answerdotai/ModernBERT-base")?
yes, that's what I'm comparing against
Could you elaborate more on your attempt and what failed?
As per the test shown above, I attempt to tokenize text containing the mask token.
If I use the vanilla BPE from your library, it splits it up fully (into separate "[" etc.).
If I use MatchTokenization, it keeps it together but adds a separate "Ġ" boundary token before it (see the test result above).
I'm not sure how to implement the correct behavior for matched/special tokens -- I mean the " [MASK]" token that the Python implementation returns.
Previously, I implemented BPE from scratch, but I had problems with tokenizing complicated Unicode characters correctly, which is why I switched to your library.
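As an aside, my understanding is that the Ġ prefix everywhere is just the GPT-2 byte-level code map at work (what gpt2_codemap / CodeNormalizer implement, if I read it correctly): the leading space byte gets remapped to a printable character. A one-liner to illustrate, assuming the same scheme as GPT-2's bytes-to-unicode table:
Char(UInt32(' ') + 0x100)   # 'Ġ' (U+0120) -- the space byte 0x20 shifted by 0x100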
MatchTokenization should return the exact match without extra space (that is actually something achievable with huggingface/transformers' tokenizer but not implemented here). I couldn't reproduce the issue, given:
_tkr = HuggingFace.load_tokenizer("answerdotai/ModernBERT-base");
julia> _bpe = _tkr.tokenizer.tokenization.base.base.base.bpe.bpe
BPE(50009 merges)
julia> _bpetkr = BPETokenizer(
    MatchTokenization(
        CodeNormalizer(
            BPETokenization(
                GPT2Tokenization(),
                _bpe
            ),
            gpt2_codemap()
        ),
        _tkr.tokenizer.tokenization.patterns
    )
)
julia> _bpetkr(TextEncoders.Sentence("The captial of France is [MASK]"))
8-element Vector{String}:
"The"
"Ġcapt"
"ial"
"Ġof"
"ĠFrance"
"Ġis"
"Ġ"
"[MASK]"
@svilupp The hardcoded result in the test has an extra space. Is that expected?
Yes, in Python, the space is folded into the special token.
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# I use the large model, but the tokenizer is the same
model_id = "answerdotai/ModernBERT-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
text = "The capital of France is [MASK]."
inputs = tokenizer(text, return_tensors="pt")
print("Tokens: ", tokenizer.tokenize(text))
print(inputs["input_ids"])
Tokens: ['The', 'Ġcapital', 'Ġof', 'ĠFrance', 'Ġis', ' [MASK]', '.']
tensor([[50281, 510, 5347, 273, 6181, 310, 50284, 15, 50282]])
Notice that "Ġis" / 310 is the only thing preceding the MASK token --> the space is somehow folded into the special token.
Ah, I see. So it uses exactly the same unimplemented feature I mentioned above. Let me check the spec of that behavior.
So that extra space part is actually quite simple. You can just replace the special_token string with Regex(raw"\s*" * Base.wrap_string(special_token, UInt32(0))), which should give you something like r"\s*\Q[MASK]\E", when you extract the special tokens with lstrip = true. MatchTokenization can take a list of regexes and prevents the matches from subsequent splits.
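Roughly something like this, reusing the names from your pipeline above (untested sketch; special_tokens is whatever you read from tokenizer.json, and [MASK] is assumed to be the only token with lstrip = true):
# Untested sketch: wrap the lstrip special tokens in regexes so the leading whitespace becomes part of the match
lstrip_pattern(tok) = Regex(raw"\s*" * Base.wrap_string(tok, UInt32(0)))   # -> r"\s*\Q[MASK]\E"
patterns = vcat(
    lstrip_pattern("[MASK]"),
    setdiff(collect(keys(special_tokens)), ["[MASK]"]),   # remaining special tokens stay as plain strings
    # (if mixing strings and regexes turns out not to work, wrap these with Base.wrap_string too)
)

tokenizer = BPETokenizer(
    MatchTokenization(
        CodeNormalizer(
            BPETokenization(GPT2Tokenization(), base_tokenizer),
            gpt2_codemap()
        ),
        patterns
    )
)
# the test sentence should now keep " [MASK]" as a single match instead of "Ġ" + "[MASK]"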
As a note for future updates:
- lstrip = true: match with r"\s*\Q<special>\E".
- rstrip = true: match with r"\Q<special>\E\s*".
- normalized = true: needs two MatchTokenization layers, one before and one after normalization.
- single_word = true: filter out matches that have r"\w" immediately before or after the match.
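For the rstrip and single_word cases, the patterns could look something like this (hypothetical helpers, not part of TextEncodeBase; single_word is expressed here with lookarounds instead of post-filtering the matches):
rstrip_pattern(tok) = Regex(Base.wrap_string(tok, UInt32(0)) * raw"\s*")                         # r"\Q<special>\E\s*"
single_word_pattern(tok) = Regex(raw"(?<!\w)" * Base.wrap_string(tok, UInt32(0)) * raw"(?!\w)")  # reject \w on either side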
Thanks!
I then had a problem getting the encoding to work (I wasn't getting the unknown token, just -1), but I just added the variation to the DictBackedLookupDict. It's not glamorous, but it should do!
You need to specify the unk token with the vocabulary, something like Vocab(strings, unk_token).
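For example, with a toy vocabulary (just to illustrate the unk fallback; in practice the strings come from the model's vocab):
using TextEncodeBase
vocab = Vocab(["[UNK]", "The", "Ġcapital", "[MASK]"], "[UNK]")
lookup(vocab, " [MASK]")   # out of vocabulary, so it falls back to the "[UNK]" index (1 here) instead of an invalid one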
I know, but I was referring to " [MASK]" (with the leading space) being matched as unknown!
Hmm, so Python would return the same index for all " [MASK]" variants?
That is my understanding:
text = "[MASK][MASK] is a [MASK]"
inputs = tokenizer(text, return_tensors="pt")
>>> print("Tokens: ", tokenizer.tokenize(text))
Tokens: ['[MASK]', '[MASK]', 'Ġis', 'Ġa', ' [MASK]']
>>> print(inputs["input_ids"])
tensor([[50281, 50284, 50284, 310, 247, 50284, 50282]])
My current battle is with multiple whitespaces.
- I should be getting a single triple-whitespace token:
>>> text = "   "
>>> inputs = tokenizer(text, return_tensors="pt")
>>> print("Tokens: ", tokenizer.tokenize(text))
Tokens: ['   ']
>>> print(inputs["input_ids"])
tensor([[50281, 50275, 50282]])
But I get "ĠĠĠ" (token 341) instead.
- I should be getting no word boundary (Ġ) on the "spaces" token (the last one), but I get
"Ġspaces" (notice that in the expected output the last token carrying a word boundary is "Ġuse"):
>>> text = "Mr. O'Neill-McPherson's co-workers @ ABC.com [and] {Dr. J.R.R. Martin-Smith} use multiple spaces!"
>>> inputs = tokenizer(text, return_tensors="pt")
>>> print("Tokens: ", tokenizer.tokenize(text))
Tokens: ['Mr', '.', 'ĠO', "'", 'Neill', '-', 'Mc', 'P', 'her', 'son', "'s", 'Ġco', '-', 'workers', 'Ġ@', 'ĠABC', '.', 'com', 'Ġ[', 'and', ']', 'Ġ{', 'Dr', '.', 'ĠJ', '.', 'R', '.', 'R', '.', 'ĠMartin', '-', 'Smith', '}', 'Ġuse', ' ', 'multiple', ' ', 'spaces', '!']
>>> print(inputs["input_ids"])
tensor([[50281, 7710, 15, 473, 8, 41437, 14, 11773, 49, 379,
1665, 434, 820, 14, 26719, 1214, 15599, 15, 681, 544,
395, 62, 551, 9034, 15, 500, 15, 51, 15, 51,
15, 8698, 14, 21484, 94, 897, 50273, 34263, 50273, 31748,
2, 50282]])
Any clues about what's happening? I suspect the regex rule splitter.
Can you check the patterns in MatchTokenization? Both "   " and "     " (3 and 5 spaces) should be in the special token list.
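E.g., for the hand-built pipeline, something like this (the full set of whitespace tokens should be read from the added_tokens in tokenizer.json; these two are just the ones relevant to the examples above):
whitespace_specials = ["   ", "     "]   # 3 and 5 spaces
# add them to whatever pattern list you pass to MatchTokenization, e.g.
patterns = vcat(collect(keys(special_tokens)), whitespace_specials)
# and for the tokenizer loaded via HuggingFace.load_tokenizer, check that they already appear in
# _tkr.tokenizer.tokenization.patterns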