
Don't override train_step in bert_train.py

Open mattdangerw opened this issue 3 years ago • 0 comments

We should be able to rework the input data for BERT pretraining so that model.fit() can be used directly, without a custom train_step.

  • We will need to rework our input data into a (features, labels, sample_weights) tuple.
  • Labels should include both "masked_lm_ids" and "next_sentence_labels".
  • Weights should be used for "masked_lm_weights" (and I guess we need to pass tf.ones for the next sentence labels?).
  • There will need to be two losses, one for next sentence prediction and one for the MLM loss, with equal weight.
  • We should keep an accuracy metric for each loss.
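The repacking step above could be sketched roughly as follows. This is plain Python standing in for the actual tf.data transformation, and the output names ("masked_lm", "next_sentence") and the pack_inputs helper are assumptions for illustration, not names from the codebase:

```python
def pack_inputs(features, masked_lm_ids, masked_lm_weights, next_sentence_labels):
    """Repack raw pretraining inputs into the (features, labels, sample_weights)
    tuple that model.fit() expects for a two-output model.

    The label and weight dicts are keyed by hypothetical output names,
    "masked_lm" and "next_sentence".
    """
    labels = {
        "masked_lm": masked_lm_ids,
        "next_sentence": next_sentence_labels,
    }
    sample_weights = {
        # Per-position weights so padded mask slots don't contribute to the loss.
        "masked_lm": masked_lm_weights,
        # All-ones weights so every next-sentence example counts equally
        # (tf.ones in the real pipeline).
        "next_sentence": [1.0] * len(next_sentence_labels),
    }
    return features, labels, sample_weights
```

With the data in this shape, the two losses and per-output accuracy metrics can be supplied as dicts keyed by the same output names in model.compile(), and fit() consumes the tuple directly.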

mattdangerw · Sep 07 '22 03:09