keras-nlp
Don't override train_step in bert_train.py
We should be able to rework the input data for bert pretraining, to use model.fit() directly without a custom train_step.
- We will need to rework our input data into a `(features, labels, label_weights)` tuple.
- Labels should have both `"masked_lm_ids"` and `"next_sentence_labels"`.
- Weights should be used for `"masked_lm_weights"` (and I guess we need to pass `tf.ones` for the next sentence labels?).
- There will need to be two losses: one for next sentence prediction and one for the MLM loss, with equal weight.
- We should keep an accuracy metric for each loss.
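For reference, the shape this could take with plain `model.fit()`. This is a minimal sketch with a toy two-output model, not the real BERT pretraining model; the output names (`"masked_lm"`, `"next_sentence"`) and layer sizes are placeholder assumptions, but the compile/fit pattern (per-output losses with equal `loss_weights`, per-output weighted accuracy metrics, and a dict of sample weights passed alongside labels) is the part that matters:

```python
import numpy as np
from tensorflow import keras

# Toy stand-in for the pretraining model. Output names here are
# hypothetical; the real model would emit per-position MLM logits.
inputs = keras.Input(shape=(8,), name="token_ids")
x = keras.layers.Dense(16, activation="relu")(inputs)
model = keras.Model(
    inputs,
    {
        "masked_lm": keras.layers.Dense(100, name="masked_lm")(x),
        "next_sentence": keras.layers.Dense(2, name="next_sentence")(x),
    },
)

model.compile(
    optimizer="adam",
    # Two losses, one per output, with equal weight.
    loss={
        "masked_lm": keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        "next_sentence": keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    },
    loss_weights={"masked_lm": 1.0, "next_sentence": 1.0},
    # An accuracy metric for each loss, respecting the sample weights.
    weighted_metrics={
        "masked_lm": keras.metrics.SparseCategoricalAccuracy(),
        "next_sentence": keras.metrics.SparseCategoricalAccuracy(),
    },
)

batch = 4
features = np.random.rand(batch, 8).astype("float32")
labels = {
    "masked_lm": np.random.randint(0, 100, size=(batch,)),
    "next_sentence": np.random.randint(0, 2, size=(batch,)),
}
# label_weights: masked_lm_weights for the MLM head, all-ones for NSP.
label_weights = {
    "masked_lm": np.ones((batch,), dtype="float32"),
    "next_sentence": np.ones((batch,), dtype="float32"),
}

history = model.fit(
    features, labels, sample_weight=label_weights, epochs=1, verbose=0
)
```

With the data pipeline emitting `(features, labels, label_weights)` tuples, the same thing works as a `tf.data.Dataset` passed straight to `fit()`, with no `train_step` override.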