
Automatic compilation and better compilation defaults for task models

Open · jbischof opened this issue 2 years ago · 1 comment

Our task models such as BertClassifier are intended to work out-of-the-box, but currently users need considerable domain knowledge or an example to compile them correctly. For example, our "Getting Started" guide is full of clunky snippets like

classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(5e-5),
    metrics=keras.metrics.SparseCategoricalAccuracy(),
    jit_compile=True,
)

Here are two changes that can improve the novice user experience:

  1. Override compile to supply reasonable defaults for these arguments
  2. Compile task models automatically in __init__ using those defaults

This would allow even AdamW with a fancy learning rate schedule to be the default behavior. Additionally, if users want to override one argument in compile, they will not lose our settings for the others (e.g., classifier.compile(jit_compile=False) for TPU).

Note: we cannot do this for backbone models since they do not have a loss function.
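To make the proposal concrete, here is a minimal sketch of the pattern the two steps describe: a compile override with sentinel defaults, invoked automatically from __init__. The class name TaskModelSketch and the "default" sentinel are illustrative assumptions, not the actual KerasNLP implementation.

```python
import keras


class TaskModelSketch(keras.Model):
    """Hypothetical task model that compiles itself with sensible defaults."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Step 2 from the proposal: compile automatically so the model
        # works out of the box.
        self.compile()

    def compile(
        self,
        optimizer="default",
        loss="default",
        metrics="default",
        jit_compile=True,
        **kwargs,
    ):
        # Step 1 from the proposal: fill in reasonable defaults for any
        # argument the user did not pass, so overriding one argument
        # (e.g. jit_compile=False) keeps the defaults for the others.
        if optimizer == "default":
            optimizer = keras.optimizers.Adam(5e-5)
        if loss == "default":
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        if metrics == "default":
            metrics = [keras.metrics.SparseCategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            jit_compile=jit_compile,
            **kwargs,
        )
```

With this pattern, classifier.compile(jit_compile=False) replaces only the one argument and still picks up the default optimizer, loss, and metrics.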

Steps

  • [x] Pilot default compilation in BERT and RoBERTa classifiers (#695)
  • [ ] Extend default compilation to other classifier task models (#709)
  • [ ] Extend default compilation to language model tasks
  • [ ] Pilot better compiler defaults for BERT and RoBERTa
  • [ ] Extend better compiler defaults to other classifier tasks

jbischof avatar Dec 28 '22 17:12 jbischof

We likely want to compile with Adam rather than AdamW so we don't throw an error on older versions of TF.

jbischof avatar Jan 11 '23 04:01 jbischof

Fixed!

mattdangerw avatar Aug 14 '24 02:08 mattdangerw