[FEA] cannot properly save and load the TF Retrieval model
### Bug description
We want to be able to save the entire Two-Tower model and load it back so that we can run `model.evaluate()` and `model.predict()`. However, we get the following error when loading the model back. To reproduce the error, please run the 05-Retrieval-Model.ipynb example, then save and reload the model with the scripts below.
First, save the model after the `model.fit()` step:

```python
model.save('two_tower')
```
Then, when we load the saved model back, we get the following error:

```python
reloaded = tf.keras.models.load_model('two_tower')
```

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [17], in <cell line: 1>()
----> 1 reloaded = tf.keras.models.load_model('two_tower')

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File /usr/local/lib/python3.8/dist-packages/keras/saving/saved_model/load.py:1000, in revive_custom_object(identifier, metadata)
    998   return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
    999 else:
-> 1000   raise ValueError(
   1001       f'Unable to restore custom object of type {identifier}. '
   1002       f'Please make sure that any custom layers are included in the '
   1003       f'`custom_objects` arg when calling `load_model()` and make sure that '
   1004       f'all layers implement `get_config` and `from_config`.')

ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all layers implement `get_config` and `from_config`.
```
### Expected behavior

The saved model can be loaded back with `tf.keras.models.load_model()` and then used for `model.evaluate()` and `model.predict()`.
### Environment details
- Merlin version:
- Platform:
- Python version:
- PyTorch version (GPU?):
- Tensorflow version (GPU?): TF 2.8.0
Using the `merlin-tensorflow-training:22.05` docker image with the latest `main` branches pulled.
This is also required by this RMP: https://github.com/NVIDIA-Merlin/Merlin/issues/271
It looks like we've got some custom objects (custom metrics in this case) that need to be specified when calling `load_model`:

```python
custom_objects = {
    "RecallAt": mm.RecallAt,
    "NDCGAt": mm.NDCGAt,
}
reloaded = tf.keras.models.load_model('two_tower', custom_objects=custom_objects)
```
We could potentially create a `load_model` helper that adds all the known custom objects automatically.
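A minimal sketch of what such a helper could look like. The `merge_custom_objects` function and the contents of the `known` registry are assumptions for illustration, not an existing Merlin API; the `mm.RecallAt`/`mm.NDCGAt` names come from the snippet above:

```python
def merge_custom_objects(known, extra=None):
    """Merge the library's known custom objects with caller-supplied ones.

    Caller-supplied entries take precedence on name clashes.
    """
    objects = dict(known)
    if extra:
        objects.update(extra)
    return objects


def load_model(path, custom_objects=None):
    """Hypothetical helper: wraps tf.keras.models.load_model and passes
    the library's known custom objects automatically."""
    import tensorflow as tf
    import merlin.models.tf as mm  # assumed import alias, as in the snippet above

    known = {
        "RecallAt": mm.RecallAt,
        "NDCGAt": mm.NDCGAt,
        # ...extended with any other custom layers/metrics the models register
    }
    return tf.keras.models.load_model(
        path, custom_objects=merge_custom_objects(known, custom_objects)
    )
```

Callers could then do `reloaded = load_model('two_tower')` without knowing which metrics the model was built with.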
@oliverholworthy I tested that, but it did not work for me.
Update: saving of `TwoTowerModel` is now working (22.08).

However, loading is still not working, due to an unbound `query` variable in `ItemRetrievalScorer` when the model is reloaded. Trying the example in the ticket description should now give a new error with a message similar to this:

```
AssertionError: Found 1 Python objects that were not bound to checkpointed values, likely due to changes in the Python program. Showing 1 of 1 unmatched objects: [<tf.Variable 'query:0' shape=(None, 4) dtype=float32, numpy=array([[0., 0., 0., 0.]], dtype=float32)>]
```
There is ongoing work in #633 that is taking us toward a place where we can replace this retrieval scorer with a new implementation that won't have this issue. Aiming for the next release, 22.09.
PR #790 introduces `TwoTowerModelV2`, which can be saved and loaded correctly.