Automatic compilation and better compilation defaults for task models
Our task models, such as BertClassifier, are intended to work out-of-the-box, but users currently need considerable domain knowledge or an example to compile them correctly. For example, our "Getting Started" guide is full of clunky snippets like:
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(5e-5),
    metrics=keras.metrics.SparseCategoricalAccuracy(),
    jit_compile=True,
)
Here are two changes that can improve the novice user experience:
- Override compile to add reasonable defaults for these args
- Compile task models automatically in __init__ using the default args
This would let us make even AdamW with a learning rate schedule the default behavior. Additionally, users who override one arg in compile would not lose our settings for the others (e.g., classifier.compile(jit_compile=False) for TPU).
Note: we cannot do this for backbone models since they do not have a loss function.
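A minimal sketch of how the two changes could fit together, not the actual implementation. To keep it runnable without TensorFlow, BaseModel is a hypothetical stand-in for keras.Model that only records compile arguments, SketchClassifier is a made-up task model, and the string values are placeholders for the Keras objects in the snippet above. The "auto" sentinel is likewise an assumption, used so that overriding one arg leaves the other defaults in place.

```python
class BaseModel:
    """Hypothetical stand-in for keras.Model; it only records compile args."""

    def compile(self, optimizer=None, loss=None, metrics=None, jit_compile=False):
        self.compile_config = {
            "optimizer": optimizer,
            "loss": loss,
            "metrics": metrics,
            "jit_compile": jit_compile,
        }


class SketchClassifier(BaseModel):
    """Hypothetical task model with default compilation."""

    def __init__(self):
        # Change 2: compile automatically with the default args.
        self.compile()

    def compile(self, optimizer="auto", loss="auto", metrics="auto", jit_compile=True):
        # Change 1: swap each "auto" sentinel for a task-specific default.
        # The strings below are placeholders for real Keras objects.
        if optimizer == "auto":
            optimizer = "AdamW(5e-5)"
        if loss == "auto":
            loss = "SparseCategoricalCrossentropy(from_logits=True)"
        if metrics == "auto":
            metrics = "SparseCategoricalAccuracy()"
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            jit_compile=jit_compile,
        )


classifier = SketchClassifier()            # compiled with defaults already
classifier.compile(jit_compile=False)      # e.g. for TPU; other defaults are kept
```

Because every argument defaults to a sentinel rather than None, a user who passes only jit_compile=False still gets the default optimizer, loss, and metrics on recompilation.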
Steps
- [x] Pilot default compilation in BERT and RoBERTa classifiers (#695)
- [ ] Extend default compilation to other classifier task models (#709)
- [ ] Extend default compilation to language model tasks
- [ ] Pilot better compiler defaults for BERT and RoBERTa
- [ ] Extend better compiler defaults to other classifier tasks
We likely want to compile with Adam rather than AdamW so we don't throw on older versions of TF.
Fixed!