
How to continue to train ModernBERT base/large in a specific domain?

Open BernieVA opened this issue 9 months ago • 2 comments

Hi, thanks for the wonderful work! I am trying to continue training ModernBERT in a specific domain. I could not find detailed documentation, so could you please help answer some of my questions? This is what I have done:

  1. I'm using 8 A100 GPUs, and the domain corpus has about 1 billion tokens in total.
  2. I created a conda env from environment.yaml and installed "flash_attn==2.6.3" with --no-build-isolation.
  3. I'm using the pretraining_documentation branch with modernbert-base-pretrain.yaml (https://github.com/AnswerDotAI/ModernBERT/tree/pretraining_documentation/yamls/modernbert).
  4. I'm also trying to continue training from a checkpoint found here: https://huggingface.co/answerdotai/ModernBERT-base-training-checkpoints/tree/main

My questions are:

  1. My data are in a CSV file with one article per row. Should I split each article into sentences, or will it be split automatically with a max_seq_len of 1024? Since the corpus is not large, I can convert it to MDS format, save it locally, and set streaming to False.
  2. How should I split the corpus into training and validation sets?
  3. Which checkpoint should I use from the Hugging Face Hub? There are three folders: pretrain, learning-rate-decay, and context-extension.
  4. Where should I specify the checkpoint in the yaml file? Is it at the bottom, with 'load_path'? Thank you!
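For questions 1 and 2, a minimal sketch of preparing the corpus, assuming the mosaicml-streaming package is installed and that the article list, split fraction, and output paths are placeholders you would adapt. The ~5% hold-out follows common practice for a corpus of this size; it is not a value from the ModernBERT repo:

```python
# Sketch: split domain articles into train/validation sets, then
# convert each split to MDS with mosaicml-streaming's MDSWriter.
# `articles`, `val_fraction`, and the output directories are
# hypothetical names for illustration.
import random


def split_train_val(articles, val_fraction=0.05, seed=42):
    """Shuffle articles and hold out a small validation slice."""
    rng = random.Random(seed)
    shuffled = list(articles)
    rng.shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


def write_mds(split, out_dir):
    """Write one split to a local MDS directory.

    Requires `pip install mosaicml-streaming`; import is deferred so
    the split logic above works without the package installed.
    """
    from streaming import MDSWriter

    with MDSWriter(out=out_dir, columns={"text": "str"}) as writer:
        for article in split:
            writer.write({"text": article})


# Usage (paths are illustrative):
#   train, val = split_train_val(all_articles)
#   write_mds(train, "./mds/train")
#   write_mds(val, "./mds/val")
```

With the data stored locally in MDS format, streaming: false in the yaml should then work as described.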

BernieVA avatar Apr 02 '25 15:04 BernieVA

Answers from a user:

  • It will be truncated to the max_seq_len in your yaml, as long as the data are in MDS format.
  • Randomly sample 3%-5% as validation and run deduplication against the training split, to measure generalization.
  • I'd use learning-rate-decay, and continue training with an extended seq_len as if it were the context-extension phase, but with your domain-specific data.
  • Yes, it's at the bottom, with load_path. Be aware of the learning rate (it's visible in the stdout logs); there's some re-scaling logic in the code.
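The last point can be sketched as a yaml fragment. The load_path key matches the answer above; the checkpoint path and filename are hypothetical placeholders for wherever you downloaded the learning-rate-decay folder, so check them against the actual modernbert-base-pretrain.yaml and the Hugging Face repo:

```yaml
# Fragment only -- merge into modernbert-base-pretrain.yaml.
# (... rest of the pretraining config above ...)

# Resume from an existing checkpoint (bottom of the yaml).
# Path and filename below are illustrative; point this at your
# local copy of the chosen checkpoint folder.
load_path: ./checkpoints/learning-rate-decay/latest-rank0.pt
```

Watch the stdout logs after resuming to confirm the learning rate picks up where the checkpoint's schedule left off.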

ahxxm avatar Apr 10 '25 11:04 ahxxm

Hey, is this something like TSDAE for domain adaptation?

rnbokade avatar Apr 18 '25 08:04 rnbokade