keras-nlp

Allow instantiating a BERT model in mixed precision from pre-trained weights

Status: Open • mattdangerw opened this issue 3 years ago • 2 comments

Our BERT models should support mixed precision as described in https://keras.io/api/mixed_precision/.

When a global policy is set with keras.mixed_precision.set_global_policy(), the models should follow it.
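A minimal sketch of what "following the global policy" means in practice, assuming Keras 3 (`import keras`). The two Dense layers are a toy stand-in for a BERT backbone, not the keras-nlp model itself. Each layer should report a float16 compute dtype while keeping float32 variables.

```python
import numpy as np
import keras

# Assumption: Keras 3; the tiny Sequential model is a stand-in for BERT.
keras.mixed_precision.set_global_policy("mixed_float16")

# Layers read the global policy at construction time, so build after setting it.
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(4),
])

for layer in model.layers:
    # compute_dtype is used for the math; variable_dtype for the stored weights.
    print(layer.name, layer.compute_dtype, layer.variable_dtype)

out = model(np.ones((2, 8), dtype="float32"))
print(out.dtype)  # a float16 dtype; the exact repr depends on the backend

# Restore the default so later code is not silently affected.
keras.mixed_precision.set_global_policy("float32")
```

The Keras mixed precision guide generally recommends keeping the final output layer in float32 for numerical stability; it is left mixed here only to show that the output dtype follows the policy.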

The models should probably also expose a dtype argument, e.g. keras_nlp.models.BertBase(dtype="mixed_float16").

This should also work when specifying a pretrained checkpoint to load.

mattdangerw • Aug 29 '22 20:08

This will need to be a little exploratory; I don't think anyone has looked into this yet. It's quite possible that things already mostly work today, but I don't know the details of how checkpoint restoration interacts with mixed precision.

mattdangerw • Aug 29 '22 20:08

I did some preliminary investigation here: https://colab.research.google.com/drive/1nyJmQq9PEHkY2OouSAM3pOqQsM9CG69t?usp=sharing.

  1. keras.mixed_precision.set_global_policy("mixed_float16") works. I checked the compute_dtype and dtype (the variable dtype) of every layer: they are float16 and float32, respectively. The output is float16, as expected.

  2. keras.Model accepts dtype as one of its kwargs: https://github.com/keras-team/keras/blob/86ed065f3b5cc86fcc8910fe7abf0ab3a1f422a9/keras/engine/training.py#L264. I tried passing float64 to BertCustom, but the dtype was not propagated to its layers. So, to make sure dtype reaches every layer, we will have to pass dtype=dtype to each layer inside BertCustom.
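The manual workaround from item 2 can be sketched as a subclassed model that forwards its own dtype argument to every sublayer. `TinyEncoder` is a hypothetical stand-in for BertCustom, Keras 3 is assumed, and the sketch uses "mixed_float16" (the policy proposed in the issue) rather than float64.

```python
import numpy as np
import keras


class TinyEncoder(keras.Model):
    """Hypothetical stand-in for BertCustom that forwards `dtype` to sublayers."""

    def __init__(self, hidden_dim=16, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Forward dtype explicitly rather than relying on automatic propagation,
        # which the investigation above found did not happen.
        self.dense_in = keras.layers.Dense(
            hidden_dim, activation="relu", dtype=dtype
        )
        self.dense_out = keras.layers.Dense(hidden_dim, dtype=dtype)

    def call(self, inputs):
        return self.dense_out(self.dense_in(inputs))


model = TinyEncoder(dtype="mixed_float16")
out = model(np.ones((2, 8), dtype="float32"))

for layer in (model.dense_in, model.dense_out):
    print(layer.name, layer.compute_dtype, layer.variable_dtype)
print(out.dtype)
```

The cost of this approach is that every layer constructor in the model has to repeat `dtype=dtype`, which is why automatic propagation in core Keras is attractive.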

Regarding (2), I discussed this with @mattdangerw: it seems this will require a decent amount of work in core Keras if we want dtype to propagate automatically to a model's layers.

abheesht17 • Sep 14 '22 12:09