Training Reproducibility
Thank you for putting together this awesome model!
I'm currently benchmarking it against other models (e.g., deberta-v3), fine-tuning with AutoModelForSequenceClassification, and I'm running into a reproducibility problem. My existing code first searches for optimal hyperparameters and then refits the final model. That code gives reproducible results for other model types (e.g., deberta-v3, roberta), but I can't get reproducible ModernBERT runs with it.
The refit uses the model_init approach that Hugging Face recommends for reproducibility:
from transformers import Trainer

final_trainer = Trainer(
    args=final_args,
    data_collator=data_collator,
    model_init=model_init,  # Trainer seeds the RNGs, then builds a fresh model from this
    train_dataset=dataset_dict_tokenized['train'],
    eval_dataset=dataset_dict_tokenized['val'],
    compute_metrics=compute_metrics,
)
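Some context on why `model_init` matters here. As I understand the Trainer, when you pass `model_init` instead of a ready-made `model`, it seeds the RNGs from `args.seed` before calling the function. That means the randomly initialized classification head should come out the same on every refit. The sketch below is a minimal stdlib illustration of that seed-then-build order. `init_head` and this `model_init(seed)` signature are hypothetical stand-ins, not Transformers code:

```python
import random

def init_head(hidden_size, num_labels, rng):
    """Hypothetical stand-in for a randomly initialized classification head."""
    return [[rng.gauss(0.0, 0.02) for _ in range(hidden_size)] for _ in range(num_labels)]

def model_init(seed, hidden_size=4, num_labels=2):
    # Seed first, then build: the same order the Trainer uses before calling model_init
    rng = random.Random(seed)
    return init_head(hidden_size, num_labels, rng)

run_a = model_init(seed=42)
run_b = model_init(seed=42)
run_c = model_init(seed=7)

print(run_a == run_b)  # same seed -> identical initial head
print(run_a == run_c)  # different seed -> different initial head
```

This covers only weight initialization. As the rest of the thread shows, ModernBERT turned out to also depend on the dtype it was loaded in.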
However, while this works for all the other models, it doesn't for ModernBERT. I initially suspected a flash-attn issue, but the problem occurs whether or not I use it.
Any advice here is much appreciated!
Ah, I solved my own issue. Unlike other models, ModernBERT must be loaded with the torch_dtype set, e.g.,
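import torch
from transformers import AutoModelForSequenceClassification

def model_init():
    return AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path,  # defined elsewhere, e.g. 'answerdotai/ModernBERT-base'
        num_labels=2,
        torch_dtype=torch.bfloat16,  # must be set for reproducibility
    )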
I noticed a flash-attn warning about this that helped me diagnose the issue. Sorry for the hasty issue post; this is resolved!
Their model README doesn't load the model with torch_dtype=torch.bfloat16. Presumably, setting it would cause the model to be loaded entirely in bfloat16? Their paper doesn't say the model is a full bfloat16 model either.
The README does, however, load the fill-mask pipeline with torch_dtype=torch.bfloat16, but perhaps that is meant for AMP?
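Here is what loading in full bfloat16 does to the weights themselves. bfloat16 keeps float32's 8-bit exponent but only 7 explicit mantissa bits, so every float32 weight gets rounded when it is cast. The stdlib sketch below imitates that cast with round-to-nearest-even on the upper 16 bits of the float32 bit pattern. `to_bf16` is a hypothetical helper written for this illustration, and it skips NaN handling:

```python
import struct

def to_bf16(x: float) -> float:
    """Round a value to bfloat16 (round-to-nearest-even) and return it as a Python float.
    Simplified sketch: NaN inputs are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # Rounding bias so the discarded low 16 bits round to nearest, ties to even
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000  # keep sign, 8 exponent bits and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bf16(0.1))    # 0.10009765625
print(to_bf16(1.001))  # 1.0
```

So two runs that load the same checkpoint with different `torch_dtype` values start from slightly different weights, and that alone can make their results diverge.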
I went with that for the same reason you pointed out (i.e., the readme) as well as a message from flash-attn during model load that suggested the same. That said, if there’s a better way, it would certainly be good to know.
I just checked and it does seem like the pipeline would be loading ModernBERT in full bfloat16, which is inconsistent with the previous example 😆
Now, it's possible they're advising full bfloat16 only for inference. I've raised an issue on their model page asking for clarification: https://huggingface.co/answerdotai/ModernBERT-base/discussions/7
Would you mind reopening this issue until we get a conclusive answer?
Re the flash-attn warning: AFAIK that message appears for any model, since flash-attn can't be used without float16 or bfloat16. But you still have the choice between AMP and full half precision, and the question is which one was used to train ModernBERT.
Unfortunately, the answer can affect reproducibility and, to some extent, accuracy.
My guess is that it is AMP rather than full bfloat16, because AFAIK full bfloat16 is rarely used for training: AMP is the usual choice, and full bfloat16 training is generally considered unstable.
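Here is a rough illustration of why pure bfloat16 training tends to be avoided. A small update applied directly to a bfloat16 weight can be rounded away entirely. AMP avoids this by keeping higher-precision master weights and running only the forward/backward math in reduced precision.

This is a stdlib sketch with a few assumptions:
- It redefines the hypothetical `to_bf16` rounding helper so it runs on its own.
- Python floats (double precision) stand in for float32 master weights.
- The learning rate, gradient and step count are arbitrary.

```python
import struct

def to_bf16(x: float) -> float:
    """Round to bfloat16 (round-to-nearest-even); NaN handling omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

lr, grad, steps = 1e-3, 1.0, 100  # arbitrary plain-SGD settings

# Full bfloat16: the weight itself is stored in bfloat16 after every update
w_bf16 = to_bf16(1.0)
for _ in range(steps):
    w_bf16 = to_bf16(w_bf16 - lr * grad)

# AMP-style: a higher-precision master weight accumulates every update
w_master = 1.0
for _ in range(steps):
    w_master -= lr * grad

print(w_bf16)              # 1.0: each 1e-3 step is below bf16 resolution near 1.0
print(round(w_master, 6))  # 0.9
```

This models only where the weights are stored, not the reduced-precision forward/backward pass. Real optimizers such as AdamW, and tricks like stochastic rounding, change the details.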
Good call and thanks for doing the diligence: reopened!
@umarbutler @davedgd I was looking through the issues to check which YAML configs they use for training, for reproducibility. The config sets:
precision: amp_bf16
Do check out this particular config, which is used for training.
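For fine-tuning with the Hugging Face Trainer, `amp_bf16` would roughly correspond to two settings:
- Load the model in its default float32, i.e. without `torch_dtype=torch.bfloat16`.
- Enable bfloat16 mixed precision with `TrainingArguments(bf16=True)`.

`hf_precision_kwargs` below is a hypothetical helper sketching that mapping. The `fp32` and `amp_fp16` strings are assumed to follow the same naming as `amp_bf16` in the training config. The helper is plain Python so it runs without Transformers installed:

```python
def hf_precision_kwargs(precision: str) -> dict:
    """Hypothetical mapping from a training-config precision string to
    Hugging Face TrainingArguments keyword arguments.

    The AMP modes assume the model was loaded in float32 (no torch_dtype override),
    so the optimizer keeps float32 weights and autocast runs the math in bf16/fp16.
    """
    mapping = {
        "fp32": {},                  # plain float32 training
        "amp_fp16": {"fp16": True},  # float16 mixed precision
        "amp_bf16": {"bf16": True},  # bfloat16 mixed precision
    }
    if precision not in mapping:
        raise ValueError(f"unsupported precision: {precision!r}")
    return dict(mapping[precision])  # return a copy so callers can't mutate the table

print(hf_precision_kwargs("amp_bf16"))  # {'bf16': True}
```

Usage would look like `TrainingArguments(output_dir="out", seed=42, **hf_precision_kwargs("amp_bf16"))`, paired with a `model_init` that omits `torch_dtype`. I haven't tested whether that setup restores run-to-run reproducibility for ModernBERT.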