keras-nlp
[Memory optimization] Reduce the memory usage of the preset() API when loading weights.
Is your feature request related to a problem? Please describe.
This is a request forwarded from @martin-gorner based on user feedback.
Currently, when creating a model from the preset() API with default weights, Keras first creates the model with random weights and then resets them with the predefined weights. This causes unnecessary memory overhead, which can be avoided by not creating the random weights in the first place.
Describe the solution you'd like
With keras.StatelessScope, we should be able to achieve the following:

```python
with keras.StatelessScope():
    model = keras_nlp.some_backbone.from_preset()
    model.load_weights(weight)
# After leaving the scope, the model weights should be properly initialized.
model.xxx()
```
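To make the intended life cycle concrete, here is a minimal, self-contained sketch of the deferred-initialization idea (plain Python stand-ins, not the real Keras API; `DeferredVariable`, `ToyStatelessScope`, and `random_init` are all hypothetical names): inside the scope, a variable records only its shape and initializer and stages any assigned value; a single buffer is materialized on scope exit, so the random initializer never runs for weights that were loaded.

```python
class DeferredVariable:
    """Toy stand-in for a framework variable: no buffer until initialized."""
    def __init__(self, shape, initializer):
        self.shape = shape
        self.initializer = initializer   # callable producing the random init
        self.pending = None              # staged value, e.g. from load_weights
        self.value = None                # real buffer, allocated lazily

    def assign(self, value):
        # Inside the scope this only stages the value; nothing is allocated.
        self.pending = value


class ToyStatelessScope:
    """Toy scope: collects variables and materializes them once, on exit."""
    def __init__(self):
        self.created = []

    def __enter__(self):
        return self

    def track(self, variable):
        self.created.append(variable)

    def __exit__(self, *exc):
        for v in self.created:
            # Only one buffer is ever allocated: a staged (loaded) value
            # wins, so the random initializer is skipped for loaded weights.
            if v.pending is not None:
                v.value = v.pending
            else:
                v.value = v.initializer(v.shape)
        return False


def random_init(shape):
    # Stand-in for an expensive random initializer.
    return [0.0] * shape[0]


with ToyStatelessScope() as scope:
    kernel = DeferredVariable(shape=(4,), initializer=random_init)
    scope.track(kernel)
    kernel.assign([1.0, 2.0, 3.0, 4.0])  # e.g. weights loaded from a preset

print(kernel.value)  # prints [1.0, 2.0, 3.0, 4.0]; random_init never ran
```

In this sketch the random-initialized buffer for `kernel` is never created, which is exactly the overhead the request wants to eliminate for preset-loaded models.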
Describe alternatives you've considered
Additional context
I did some experiments in https://colab.research.google.com/drive/1pEXZPBei_6RMYm7Vo8MDG8BphWSbxjYW?authuser=0#scrollTo=NYeH2sQqkeQx, but hit an error with the current preset() API, since it tried to access/use the uninitialized variables before leaving the scope. The stateless scope probably needs to be added within the preset() function itself to ensure the proper variable life cycle.