
On the performance of token classification

Open stefan-it opened this issue 1 year ago • 16 comments

Hi guys,

I spent some time evaluating ModernBERT Large on the standard CoNLL-2003 shared task.

For fine-tuning, I used the following command:

python3 run_ner.py \
  --model_name_or_path answerdotai/ModernBERT-large \
  --dataset_name conll2003 \
  --output_dir ./output \
  --do_train \
  --do_eval \
  --do_predict \
  --seed 1 \
  --num_train_epochs 10 \
  --eval_strategy epoch \
  --save_strategy epoch \
  --load_best_model_at_end \
  --per_device_train_batch_size 4 \
  --learning_rate 5e-06 \
  --trust_remote_code=True

Unfortunately, I was only able to reach 95.20% on the development set and 90.54% on the test set.

Additionally, I fine-tuned with the Flair library, which does not use the ModernBertForTokenClassification implementation, and it also yielded about 95% on the development set and ~90% on the test set.
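For readers comparing these numbers: CoNLL-2003 scores are entity-level micro F1 (run_ner.py computes them with seqeval), so a predicted entity counts only if both its span boundaries and its type match the gold entity. The following is a minimal pure-Python sketch of that metric, assuming IOB2 tags; it approximates seqeval's default behaviour and is not its actual implementation.

```python
from typing import List, Optional, Set, Tuple

Span = Tuple[str, int, int]  # (entity type, start index, end index exclusive)


def bio_spans(tags: List[str]) -> Set[Span]:
    """Extract entity spans from a BIO tag sequence (conlleval-style:
    an I-X that does not continue an X chunk starts a new chunk)."""
    spans: Set[Span] = set()
    current_type: Optional[str] = None
    start = 0
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes the last chunk
        prefix, _, etype = tag.partition("-")
        continues = prefix == "I" and etype == current_type
        if current_type is not None and not continues:
            spans.add((current_type, start, i))
            current_type = None
        if prefix in ("B", "I") and not continues:
            current_type, start = etype, i
    return spans


def entity_f1(gold: List[List[str]], pred: List[List[str]]) -> float:
    """Micro-averaged F1 over exact (type, start, end) span matches."""
    tp = n_gold = n_pred = 0
    for g, p in zip(gold, pred):
        gs, ps = bio_spans(g), bio_spans(p)
        tp += len(gs & ps)
        n_gold += len(gs)
        n_pred += len(ps)
    if tp == 0:
        return 0.0
    precision, recall = tp / n_pred, tp / n_gold
    return 2 * precision * recall / (precision + recall)


gold = [["B-PER", "I-PER", "O", "B-LOC"]]
pred = [["B-PER", "I-PER", "O", "O"]]  # the LOC entity is missed
print(round(entity_f1(gold, pred), 4))  # 0.6667
```

Missing one of two entities halves recall, which is why small tokenization problems around entity boundaries can move these scores noticeably.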

Now the question is: is there some kind of bug in the currently released Transformers version (I tested with https://github.com/huggingface/transformers/commit/b5a557e5fe2d015bd36214a95878370eaed51571), or are more tricks needed to get token classification working? :thinking:

Many thanks in advance!

stefan-it • Dec 21 '24

Hi there!

As a heads-up, I think we'll be a bit slow to look deeper into issues over the next two weeks (myself included) as we scatter for the holidays. Sorry about that!

In the meantime, one immediate candidate I can see is the learning rate. In our experience (as hinted in the paper's appendix tables, though perhaps not strongly enough), ModernBERT benefits greatly from higher LRs. Our best results were always achieved with LRs in the 3e-05 to 1e-04 range, depending on the task, so 5e-06 is probably quite a bit too low. Hopefully this helps!

(edit: I’m not terribly familiar with NER training, but could the small batch size also have an impact? For retrieval tasks for example, very small batches are definitely detrimental.)
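One way to act on both suggestions is a small grid over learning rate and batch size. The sketch below only builds run_ner.py command lines mirroring the original command; the LR values follow the 3e-05 to 1e-04 range mentioned above, while the batch sizes and output-directory naming are illustrative assumptions, not recommendations from the thread.

```python
import itertools
import shlex
from typing import List

# LR range follows the advice above; batch sizes are illustrative assumptions.
LEARNING_RATES = ["3e-05", "5e-05", "1e-04"]
BATCH_SIZES = [4, 16, 32]


def sweep_commands(model: str = "answerdotai/ModernBERT-large") -> List[str]:
    """Build one run_ner.py command line per (learning rate, batch size) pair."""
    commands = []
    for lr, bs in itertools.product(LEARNING_RATES, BATCH_SIZES):
        args = [
            "python3", "run_ner.py",
            "--model_name_or_path", model,
            "--dataset_name", "conll2003",
            "--output_dir", f"./output/lr{lr}_bs{bs}",  # hypothetical naming scheme
            "--do_train", "--do_eval", "--do_predict",
            "--seed", "1",
            "--num_train_epochs", "10",
            "--eval_strategy", "epoch",
            "--save_strategy", "epoch",
            "--load_best_model_at_end",
            "--per_device_train_batch_size", str(bs),
            "--learning_rate", lr,
            "--trust_remote_code=True",
        ]
        commands.append(shlex.join(args))  # shell-safe quoting (Python 3.8+)
    return commands


if __name__ == "__main__":
    for cmd in sweep_commands():
        print(cmd)
```

Keeping the seed fixed across the grid makes the runs comparable, though a single seed per configuration is still a noisy basis for picking a winner.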

bclavie • Dec 21 '24

Hello,

As mentioned by Benjamin, the LR could be at fault. However, there seems to be another report of poor results for token classification. As I responded in this thread, our tokenizer is somewhat specific: most tokens start with a whitespace, which causes some issues for mask prediction and pushed us to propose a fix, as described here.

I am not sure this is the reason for your results, but it could be.

NohTow • Dec 22 '24

Hey @bclavie and @NohTow,

Many thanks for your responses! I experimented with different learning rates (see https://github.com/stefan-it/modern-bert-ner), but this only changed the performance slightly, moving it into the 96% range.

Regarding the tokenizer issue, here is what I have debugged so far: the token classification example uses this tokenizer call (with the is_split_into_words option):

https://github.com/huggingface/transformers/blob/8f38f58f3de5a35f9b8505e9b48985dce5470985/examples/pytorch/token-classification/run_ner.py#L442-#L450

Demo:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
tokenizer(["Today", "it", "is", "snowing", "in", "Munich", "."], is_split_into_words=True)

# Outputs: {'input_ids': [50281, 14569, 262, 261, 84, 2666, 272, 249, 46, 328, 469, 15, 50282], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

tokenizer.convert_ids_to_tokens([50281, 14569, 262, 261, 84, 2666, 272, 249, 46, 328, 469, 15, 50282])

# Outputs: ['[CLS]', 'Today', 'it', 'is', 's', 'now', 'ing', 'in', 'M', 'un', 'ich', '.', '[SEP]']

This output is quite interesting compared to the RoBERTa output:

from transformers import AutoTokenizer

# add_prefix_space=True is required when calling RoBERTa's tokenizer with is_split_into_words=True
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
tokenizer(["Today", "it", "is", "snowing", "in", "Munich", "."], is_split_into_words=True)

# Outputs: {'input_ids': [0, 2477, 24, 16, 1958, 154, 11, 10489, 479, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

tokenizer.convert_ids_to_tokens([0, 2477, 24, 16, 1958, 154, 11, 10489, 479, 2])

# Outputs: ['<s>', 'ĠToday', 'Ġit', 'Ġis', 'Ġsnow', 'ing', 'Ġin', 'ĠMunich', 'Ġ.', '</s>']
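The "Ġ" in these token strings is how GPT-2-style byte-level BPE vocabularies (used by RoBERTa's tokenizer and, judging by the outputs in this thread, by ModernBERT's as well) display a leading space: every byte is mapped to a printable character, and the space byte 0x20 becomes "Ġ" (U+0120). The sketch below is a reimplementation of GPT-2's byte table written for illustration, not code imported from any library:

```python
from typing import Dict


def bytes_to_unicode() -> Dict[int, str]:
    """GPT-2-style byte-to-printable-character table (reimplemented).

    Printable bytes map to themselves; the remaining bytes, including the
    space (0x20), are shifted to code points starting at 256.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values, code_points = printable[:], printable[:]
    shift = 0
    for b in range(256):
        if b not in printable:
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}


BYTE_ENCODER = bytes_to_unicode()


def display_form(text: str) -> str:
    """Render text the way byte-level BPE vocabularies store it."""
    return "".join(BYTE_ENCODER[b] for b in text.encode("utf-8"))


print(display_form(" Munich"))  # ĠMunich
print(display_form("Munich"))   # Munich
```

So "ĠMunich" and "Munich" are different vocabulary entries: a word fed to the tokenizer without its leading space can only match the space-less pieces, which tend to be shorter and rarer.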

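I guess that, as @NohTow pointed out, the whitespace handling is the culprit: with is_split_into_words=True, the ModernBERT tokenizer does not add a leading space to each word, so no token in the converted output starts with "Ġ", and words such as "snowing" and "Munich" get split into unusual pieces.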

However, using a "normal" (not split into words) tokenization looks ok:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
tokenizer("Today it is snowing in Munich.")

# Outputs: {'input_ids': [50281, 14569, 352, 310, 8762, 272, 275, 32000, 15, 50282], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

tokenizer.convert_ids_to_tokens([50281, 14569, 352, 310, 8762, 272, 275, 32000, 15, 50282])
# Outputs: ['[CLS]', 'Today', 'Ġit', 'Ġis', 'Ġsnow', 'ing', 'Ġin', 'ĠMunich', '.', '[SEP]']
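To see why this fragmentation matters for NER specifically: with its default settings (label_all_tokens=False), run_ner.py puts each word's label on its first sub-token only, and masks continuation sub-tokens and special tokens with -100, using the tokenizer's word_ids(). The sketch below reproduces that alignment in plain Python; the word_ids list is transcribed by hand from the is_split_into_words=True ModernBERT output above, and the label ids are illustrative assumptions.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # positions ignored by PyTorch's cross-entropy loss


def align_labels(word_ids: List[Optional[int]], word_labels: List[int]) -> List[int]:
    """Label the first sub-token of each word; mask special tokens and
    continuation sub-tokens (mirrors run_ner.py with label_all_tokens=False)."""
    aligned: List[int] = []
    previous: Optional[int] = None
    for word_idx in word_ids:
        if word_idx is None:
            aligned.append(IGNORE_INDEX)  # [CLS] / [SEP]
        elif word_idx != previous:
            aligned.append(word_labels[word_idx])  # first sub-token of a word
        else:
            aligned.append(IGNORE_INDEX)  # continuation sub-token
        previous = word_idx
    return aligned


# Hand-transcribed from the is_split_into_words=True output above:
# ['[CLS]', 'Today', 'it', 'is', 's', 'now', 'ing', 'in', 'M', 'un', 'ich', '.', '[SEP]']
word_ids = [None, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, None]
# Hypothetical label ids: 0 = O, 5 = B-LOC (for "Munich").
word_labels = [0, 0, 0, 0, 0, 5, 0]

print(align_labels(word_ids, word_labels))
# [-100, 0, 0, 0, 0, -100, -100, 0, 5, -100, -100, 0, -100]
```

Here the B-LOC label for "Munich" is attached to the single-character piece "M", while "un" and "ich" are excluded from the loss.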

I will now try to find a way to fix this :)

stefan-it • Dec 22 '24

One quick fix is to load the tokenizer with:

from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("answerdotai/ModernBERT-large", add_prefix_space=True)

With my Flair benchmarks, I am now able to get 96.24% on the development set and 92.01% on the test set. I am currently doing some more runs with different hyper-parameters :)

stefan-it • Dec 22 '24

Great catch!

From the RobertaTokenizer docs ...

When used with is_split_into_words=True, this tokenizer needs to be instantiated with add_prefix_space=True.

ModernBERT uses a BPE tokenizer, and based on my reading of the docs, you're required to set this even for RoBERTa (which also uses a BPE tokenizer). The reason you don't have to set it is that the RobertaTokenizer does it for you automatically when you use is_split_into_words=True ... see here in the source code.

Perhaps we should consider making AutoTokenizer work with ModernBERT the same way???
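As far as I can tell, the slow RobertaTokenizer handles this in its prepare_for_tokenization hook: if is_split_into_words or add_prefix_space is set and the text does not already start with whitespace, a space is prepended, and with pre-split input this happens word by word. The functions below are a standalone paraphrase of that rule for illustration, not the library code itself; the helper names are hypothetical.

```python
from typing import List


def prepare_word(text: str, is_split_into_words: bool = False,
                 add_prefix_space: bool = False) -> str:
    """Paraphrase of the prefix-space rule in the slow RobertaTokenizer."""
    if (is_split_into_words or add_prefix_space) and text and not text[0].isspace():
        return " " + text
    return text


def prepare_words(words: List[str], is_split_into_words: bool = True) -> List[str]:
    """Apply the rule to each word, as happens for pre-split input."""
    return [prepare_word(w, is_split_into_words=is_split_into_words) for w in words]


words = ["Today", "it", "is", "snowing", "in", "Munich", "."]
print(prepare_words(words))
# [' Today', ' it', ' is', ' snowing', ' in', ' Munich', ' .']
print(prepare_words(words, is_split_into_words=False))
# ['Today', 'it', 'is', 'snowing', 'in', 'Munich', '.']
```

This matches the RoBERTa output earlier in the thread, where every word, including the first one and ".", comes out with a "Ġ" prefix.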

ohmeow • Dec 23 '24

I think it can be changed in the tokenizer_config.json by adding/modifying:

"tokenizer_class": "RobertaTokenizerFast",
"add_prefix_space": true,

But it needs to be tested carefully first :)
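A minimal sketch of that change applied to a local copy of a tokenizer config, using only the standard json module. The function names and the example dict are hypothetical (the values are illustrative, not copied from the real ModernBERT config), and, as noted above, the resulting tokenizer would need careful testing before use.

```python
import json
from pathlib import Path
from typing import Any, Dict


def patch_tokenizer_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a tokenizer_config.json dict with the two suggested keys set."""
    patched = dict(config)  # shallow copy; the input dict is left untouched
    patched["tokenizer_class"] = "RobertaTokenizerFast"
    patched["add_prefix_space"] = True
    return patched


def patch_file(path: Path) -> None:
    """Rewrite a local tokenizer_config.json in place (hypothetical local checkpoint)."""
    config = json.loads(path.read_text(encoding="utf-8"))
    path.write_text(json.dumps(patch_tokenizer_config(config), indent=2), encoding="utf-8")


if __name__ == "__main__":
    # Illustrative values only, not the real ModernBERT tokenizer_config.json.
    example = {"tokenizer_class": "PreTrainedTokenizerFast", "model_max_length": 8192}
    print(patch_tokenizer_config(example))
```

Patching a local copy rather than the published checkpoint keeps the original tokenizer available for comparison runs.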

stefan-it • Dec 23 '24

One quick fix is to load the tokenizer with:

from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("answerdotai/ModernBERT-large", add_prefix_space=True)

With my Flair benchmarks, I am now able to get 96.24% on the development set and 92.01% on the test set. I am currently doing more runs with different hyper-parameters :)

Compared to the original BERT version, are there currently any advantages?😯

GithubX-F • Dec 31 '24

This issue still persists and needs fixing. In universal_ner, for instance, DeBERTa v3 outperforms ModernBERT by 20% F1 with the same hyperparameters.

BramVanroy • Nov 25 '25

Hey @BramVanroy, can you try my forked and modified ModernBERT repo? (See the identifier here: https://github.com/stefan-it/modern-bert-ner?tab=readme-ov-file#results-iii)

The difference should not be 20% but around 1-2% :)

stefan-it • Nov 25 '25

That's great! Such a shame that the model was abandoned after release already.

BramVanroy • Nov 26 '25

That's great! Such a shame that the model was abandoned after release already.

That's not very pleasant to read. I did not answer your specific message (from yesterday!) because I was in Seoul for CIKM and then took one week of vacation. We have continued supporting this repository as much as we could, and I try to answer most of the issues, but unfortunately there is only so much time in a day. As a reminder, this is a research repository (one that, although a bit vanilla, is usable: a lot of people have trained their own version of the model with it). We are also supporting the whole ModernBERT ecosystem, not only this repo. For example, for this particular issue, we actually fixed a global tokenizer issue in transformers. So it should now work fine out of the box in HF when calling the tokenizer with is_split_into_words=True, possibly also with add_prefix_space=True. We decided not to make this the default because the models haven't been trained this way, but again, this was not an issue on our side but in HF, and we fixed it for everyone.

Hearing that we abandoned the model, while I am personally doing my best on this repo, shipping quite a few fine-tunes and helping as much as I can, is a bit hurtful.

A little help (like @stefan-it's debugging) is always welcome and more productive.

NohTow • Nov 26 '25

I understand, @NohTow, and I apologise for my harsh phrasing. The previous comment that went unanswered was from 2024 so I assumed interaction is low and issues are not fixed in the issue tracker. But, as you mentioned, I understand that within the scope of research, and given the time investment it takes to maintain a research project, that bar is high to meet.

From a user perspective, however, I hope you also understand where I am coming from. I see your model and am super excited about it because "yay, finally a new BERT!", but then performance tumbles. If the solution is "just use the fixed model repo from someone else", then as a user I am left with the feeling that "someone else fixed it, and the fixes have not been implemented upstream", which is not confidence-inspiring. The back story of the fixes you made within transformers is then less relevant to me as an end user, despite its clear usefulness!

Would it not be better if the fixes of Stefan are integrated in the main model repo? If need be in a v1.1 release? Again, new users will have no idea about these pitfalls and will unfairly just conclude "okay, so ModernBERT is just bad?"

Again, I apologize for my phrasing, I was too quick and too harsh. The model and work behind it is impressive and very needed in the OSS community.

BramVanroy • Nov 26 '25

All good!

The previous comment that went unanswered was from 2024 so I assumed interaction is low and issues are not fixed in the issue tracker.

Yeah, that's on us. We fixed it and communicated the information here and there (for example, in the linked issue and through DMs), but we should have updated this issue as well. I am pretty sure I did so in another issue, but I should have done it here too anyway.

Would it not be better if the fixes of Stefan are integrated in the main model repo? If need be in a v1.1 release? Again, new users will have no idea about these pitfalls and will unfairly just conclude "okay, so ModernBERT is just bad?"

Well, as I said, we made the fixes required to make it work upstream within transformers (without having to patch the tokenizer): you just need to pass the parameters to the tokenizer. Those could indeed be set in the config of the models, but as mentioned earlier, although it should not be that harmful, the models have not been trained with this space, and we cannot guarantee that there would be no edge effects. So the decision was made to let the NER community specify the parameter (which can then be set in the config of the fine-tuned model): the model will be fine-tuned anyway in this case, so it will adapt to the new tokenization.

FWIW, we also planned to release the new models, as well as a v1.5 of them (further trained with this fix), with this parameter on by default. Unfortunately, those models have not been finished (yet?).

TL;DR: the fix is already upstream and accessible through a parameter, but we did not make that parameter the default for our models because we cannot guarantee it won't have side effects, since it differs from the training setup.

NohTow • Nov 26 '25

Great, thanks for the update!

BramVanroy • Nov 26 '25

I see this tokenizer issue is also the reason why, in https://arxiv.org/abs/2509.06888, NER with mmBERT-base gave worse results than XLM-R. It's a pity. It is still unclear to me how this fix by @stefan-it could be applied to the mmBERT models as well. Maybe @orionw knows?


jwijffels • Nov 28 '25

@jwijffels the fix has also been applied to the mmBERT models from the beginning, so sadly there is nothing to change. The issue was in the pre-training and not in the released models.

If this is a major issue for you, you could continue pre-training with a language-modeling loss on correctly tokenized text, but that is a non-trivial setup. Otherwise, you should definitely try XLM-R if you're doing NER.
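For orientation only: continued pre-training with a masked-language-modeling loss comes down to masking a fraction of the (now correctly tokenized) input ids and training the model to recover them. The sketch below shows just that masking step in plain Python. The 30% rate mirrors the masking rate reported in the ModernBERT paper; the mask id is a hypothetical placeholder, and always substituting the mask token (no 80/10/10 split) is a simplifying assumption.

```python
import random
from typing import List, Sequence, Set, Tuple

IGNORE_INDEX = -100  # label value ignored by the loss


def mask_tokens(input_ids: Sequence[int], mask_id: int, special_ids: Set[int],
                rate: float = 0.3, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Replace a fraction of non-special tokens with mask_id.

    Returns (masked_inputs, labels): labels keep the original id at masked
    positions and IGNORE_INDEX elsewhere, so only masked tokens are scored.
    """
    rng = random.Random(seed)
    candidates = [i for i, tok in enumerate(input_ids) if tok not in special_ids]
    n_mask = max(1, round(rate * len(candidates))) if candidates else 0
    chosen = set(rng.sample(candidates, n_mask))
    masked = [mask_id if i in chosen else tok for i, tok in enumerate(input_ids)]
    labels = [tok if i in chosen else IGNORE_INDEX for i, tok in enumerate(input_ids)]
    return masked, labels


# Ids for "Today it is snowing in Munich." as tokenized earlier in this thread.
ids = [50281, 14569, 352, 310, 8762, 272, 275, 32000, 15, 50282]
MASK_ID = 50284  # hypothetical; use tokenizer.mask_token_id in practice
masked, labels = mask_tokens(ids, MASK_ID, special_ids={50281, 50282})
print(masked)
print(labels)
```

The hard part of such a setup is not this step but data preparation at scale and choosing how long to train so the model adapts to the new tokenization without forgetting.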

orionw • Nov 28 '25