
Custom objects support when pickling keras models

Open mthiboust opened this issue 2 months ago • 3 comments

Following the recent work on pickle support for Keras 3 (PRs https://github.com/keras-team/keras/pull/19555 and https://github.com/keras-team/keras/pull/19592), it would be useful to embed the definitions of custom objects directly in the pickled file. This would decouple training from inference and allow generic inference code: the custom object definitions would no longer need to be available in the environment where the model is loaded.
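For context on why the definitions need to travel with the file: the standard `pickle` module stores classes and functions by reference (module path plus qualified name), not by value, so the loading environment must be able to import them. The stdlib-only sketch below uses a hypothetical module `my_custom_layers` and class `MyLayer` as stand-ins for a file of custom Keras objects. It shows that a payload pickled while the module was importable cannot be loaded once the module is gone.

```python
import pickle
import sys
import types

# Hypothetical module standing in for a file of custom Keras objects.
mod = types.ModuleType("my_custom_layers")
exec(
    "class MyLayer:\n"
    "    def __init__(self, units):\n"
    "        self.units = units\n",
    mod.__dict__,
)
sys.modules["my_custom_layers"] = mod

# Pickling works while the module is importable...
payload = pickle.dumps(mod.MyLayer(4))

# ...but the payload only names "my_custom_layers" and "MyLayer"; the class
# body is not stored. Simulate an inference environment without the module.
del sys.modules["my_custom_layers"]
try:
    pickle.loads(payload)
    loaded_without_definition = True
except ModuleNotFoundError:
    loaded_without_definition = False

print(loaded_without_definition)  # False
```

This is the gap the monkey patch below tries to close, by having cloudpickle (which can serialize definitions made in `__main__` by value) write the custom objects into the pickle.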

For now, I am toying with this monkey patch:

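```python
import io

import cloudpickle
from keras.utils import get_custom_objects
from keras.src.saving.keras_saveable import KerasSaveable


# Monkey patch the KerasSaveable class to save custom Keras objects along with
# the model. With the standard `pickle` module as the outer pickler, the
# process that unpickles the model must apply this patch too, since only a
# reference to `_unpickle_model` is stored, not its code.


def new_unpickle_model(cls, model_buf, custom_objects_buf):
    import keras.src.saving.saving_lib as saving_lib

    # pickle is not safe regardless of what you do.
    custom_objects = cloudpickle.load(custom_objects_buf)
    return saving_lib._load_model_from_fileobj(
        model_buf, custom_objects=custom_objects, compile=True, safe_mode=False
    )


def new__reduce__(self):
    """__reduce__ is used to customize the behavior of `pickle.dumps()`.

    The method returns a tuple of two elements: a callable, and a tuple of
    arguments to pass to that callable. Here the Keras saving library
    serializes the model, and cloudpickle serializes the registered custom
    objects."""
    import keras.src.saving.saving_lib as saving_lib

    model_buf = io.BytesIO()
    saving_lib._save_model_to_fileobj(self, model_buf, "h5")

    # cloudpickle embeds classes defined in `__main__` by value; classes from
    # importable modules are still stored by reference unless their module is
    # first passed to `cloudpickle.register_pickle_by_value()`.
    custom_objects_buf = io.BytesIO()
    cloudpickle.dump(get_custom_objects(), custom_objects_buf)
    custom_objects_buf.seek(0)

    return (
        self._unpickle_model,
        (model_buf, custom_objects_buf),
    )


# The standard `pickle` module serializes a bound method as
# getattr(cls, func.__name__), so the function's __name__ must match the
# attribute it is installed under, or loading fails with AttributeError.
new_unpickle_model.__name__ = "_unpickle_model"
KerasSaveable._unpickle_model = classmethod(new_unpickle_model)
KerasSaveable.__reduce__ = new__reduce__
```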
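When a function is installed on a class under a different attribute name, as `new_unpickle_model` is installed as `_unpickle_model` above, the name recorded in the pickle matters. The standard `pickle` module serializes a bound method as `getattr(owner, method.__func__.__name__)`, so the loader is looked up under the function's own name at load time. The stdlib-only sketch below reproduces this with a hypothetical `Saveable` class standing in for `KerasSaveable`; cloudpickle may serialize methods differently, so this mainly concerns the case where the outer pickler is the standard library one.

```python
import pickle


class Saveable:
    """Hypothetical stand-in for KerasSaveable."""

    def __init__(self, weights, extras=None):
        self.weights = weights
        self.extras = extras

    def __reduce__(self):
        # Same shape as the patched __reduce__: (bound classmethod, args).
        return (self._unpickle, (self.weights, {"scale": 2}))


def new_unpickle(cls, weights, extras):
    return cls(weights, extras)


# Install the function under a different attribute name, as the patch does.
Saveable._unpickle = classmethod(new_unpickle)

try:
    pickle.loads(pickle.dumps(Saveable([1.0, 2.0])))
    mismatched_name_ok = True
except AttributeError:
    # The bound method was pickled as getattr(Saveable, "new_unpickle").
    mismatched_name_ok = False

# Aligning __name__ with the attribute name makes the lookup succeed.
new_unpickle.__name__ = "_unpickle"
Saveable._unpickle = classmethod(new_unpickle)
restored = pickle.loads(pickle.dumps(Saveable([1.0, 2.0])))

print(mismatched_name_ok, restored.weights, restored.extras)
# False [1.0, 2.0] {'scale': 2}
```

Because only the attribute name is stored, the process that loads the pickle also resolves `_unpickle` on its own copy of the class, which is why the loading side needs the same patch in this setup.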

Is this the right approach? Would you welcome integrating such a change directly into https://github.com/keras-team/keras/blob/master/keras/src/saving/keras_saveable.py ?

I know that pickle is not recommended for security reasons when you load a pickled model from an untrusted source. But there are cases where the same entity both dumps and loads the models.
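The Python `pickle` documentation suggests signing data with `hmac` when you need to ensure it has not been tampered with. For the case where the same party both dumps and loads models, a minimal stdlib-only sketch of that idea follows; `SECRET_KEY`, `signed_dumps`, and `verified_loads` are hypothetical names, and key management is left out.

```python
import hashlib
import hmac
import pickle

# Hypothetical shared secret known only to the party that dumps and loads.
SECRET_KEY = b"replace-with-a-real-secret"
TAG_SIZE = hashlib.sha256().digest_size  # 32 bytes


def signed_dumps(obj, key=SECRET_KEY):
    payload = pickle.dumps(obj)
    tag = hmac.new(key, payload, hashlib.sha256).digest()
    return tag + payload  # SHA-256 tag followed by the pickle bytes


def verified_loads(blob, key=SECRET_KEY):
    tag, payload = blob[:TAG_SIZE], blob[TAG_SIZE:]
    expected = hmac.new(key, payload, hashlib.sha256).digest()
    # Refuse to unpickle anything this key did not sign.
    if not hmac.compare_digest(tag, expected):
        raise ValueError("signature mismatch: refusing to unpickle")
    return pickle.loads(payload)


blob = signed_dumps({"units": 4})
print(verified_loads(blob))  # {'units': 4}
```

The signature only shows that the bytes came from a holder of the key; it does not make unpickling data from an untrusted source any safer.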

Tagging @LukeWood who may have an opinion on this.

mthiboust, Jun 10 '24 23:06