Allow instantiating a BERT model in mixed precision from pre-trained weights
Our BERT models should support mixed precision as described in https://keras.io/api/mixed_precision/.
The models should follow the global policy when one is set via `keras.mixed_precision.set_global_policy()`.
The models should probably also expose a `dtype` argument, e.g. `keras_nlp.models.BertBase(dtype="mixed_float16")`.
This should also work when specifying a pretrained checkpoint to load.
This will need to be a little exploratory; I don't think anyone has looked into this yet! It's definitely possible things will already mostly work today; I don't know the details of checkpoint restoration with mixed precision.
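To make the intent concrete, a rough sketch of what usage could look like (the `dtype` argument on `keras_nlp.models.BertBase` is the proposal above, not an existing API, and the constructor arguments are placeholders):

```python
from tensorflow import keras
import keras_nlp

# Option 1: models follow the global mixed precision policy.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_nlp.models.BertBase()  # constructor arguments omitted

# Option 2 (proposed): pass the policy directly to the constructor,
# which should also work together with loading pretrained weights.
model = keras_nlp.models.BertBase(dtype="mixed_float16")
```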
Did some preliminary investigation here: https://colab.research.google.com/drive/1nyJmQq9PEHkY2OouSAM3pOqQsM9CG69t?usp=sharing.
1. `keras.mixed_precision.set_global_policy("mixed_float16")`: Works. I checked the `compute_dtype` and `dtype` of every layer, which are `float16` and `float32`, respectively. The output is `float16`, as expected (see the sketch after this list).
2. `keras.Model` has `dtype` as part of its `kwargs`: https://github.com/keras-team/keras/blob/86ed065f3b5cc86fcc8910fe7abf0ab3a1f422a9/keras/engine/training.py#L264. I tried passing `float64` to `BertCustom`, but it did not work. So, in order to make sure `dtype` is propagated through all the layers, we will have to pass `dtype=dtype` in all layers of `BertCustom`.
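For reference, the check in (1) amounts to something like the following, shown here on a toy model rather than the actual BERT model from the colab:

```python
from tensorflow import keras
from tensorflow.keras import layers

keras.mixed_precision.set_global_policy("mixed_float16")

# Toy stand-in for the BERT model inspected in the colab.
model = keras.Sequential([
    layers.Dense(8, activation="relu"),
    layers.Dense(1),
])
model.build(input_shape=(None, 4))

for layer in model.layers:
    # Under "mixed_float16", compute_dtype is float16 while the variable
    # dtype stays float32.
    print(layer.name, layer.compute_dtype, layer.dtype)
```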
Regarding (2), I had a discussion with @mattdangerw; it seems like this will require a decent amount of work in core Keras (if we want automatic propagation of `dtype` to the model's layers).
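Without that core Keras support, the manual propagation from (2) would mean threading `dtype` through every sublayer by hand, roughly like this (illustrative only, not the actual `BertCustom` code):

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

class TinyEncoder(keras.Model):
    """Illustrative model that passes `dtype` to every sublayer by hand."""

    def __init__(self, hidden_dim=64, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Each sublayer needs dtype passed explicitly; this is the
        # "decent amount of work" mentioned above.
        self.embedding = layers.Embedding(1000, hidden_dim, dtype=dtype)
        self.dense = layers.Dense(hidden_dim, activation="gelu", dtype=dtype)
        self.norm = layers.LayerNormalization(dtype=dtype)

    def call(self, inputs):
        x = self.embedding(inputs)
        x = self.dense(x)
        return self.norm(x)

model = TinyEncoder(dtype="mixed_float16")
outputs = model(np.zeros((2, 16), dtype="int32"))
print(outputs.dtype)  # float16
```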