keras-nlp
Register MLMHead as a serializable Keras object
We would like to be able to annotate keras_nlp.layers.MLMHead with @keras.utils.register_keras_serializable(package="keras_nlp"), which would allow the Python object for the layer to be restored after saving and reloading a model.
However, doing this naively does not work when the layer shares its embedding weight with an earlier embedding layer in the model.
We need to figure out the correct approach for serializing a shared weight here.