
Change tensorflow model format to `SavedModel` to support sub-classed models

ascillitoe opened this pull request 1 year ago • 4 comments

This PR changes the format we use to serialize TensorFlow models from the old HDF5 to the newer SavedModel format.

Motivation

As well as being the default (and recommended) TensorFlow model format, the SavedModel format has the advantage of supporting serialisation of sub-classed TensorFlow models. That is, models constructed by subclassing tf.keras.Model, rather than via the sequential (tf.keras.Sequential) or functional (tf.keras.models.Model) APIs.
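For reference, a minimal sketch of the kind of sub-classed model this enables (illustrative only; the class and layer choices here are hypothetical, not from the PR):

```python
import tensorflow as tf

# A minimal sub-classed model: serializable as SavedModel, but not as HDF5.
class MyClassifier(tf.keras.Model):
    def __init__(self, n_classes: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(16, activation="relu")
        self.head = tf.keras.layers.Dense(n_classes, activation="softmax")

    def call(self, x):
        return self.head(self.dense(x))

model = MyClassifier()
_ = model(tf.zeros((1, 4)))  # sub-classed models must be called once to be built
```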

Limitations

  • The new format slightly improves the handling of custom TensorFlow objects (such as layers and models). If these are not passed to tf.keras.models.load_model via custom_objects (or registered with @tf.keras.utils.register_keras_serializable()), the model is loaded as a keras.saving.saved_model.load.<model_class>*. This is a rough copy of the original serialized model that behaves identically for inference, but cannot be cloned (which a number of learned detectors, such as ClassifierDrift, rely on). To load the fully-functional model, all custom objects must be supplied at load time.
  • The HiddenOutput class does not work for subclassed models. Therefore, subclassed models cannot be saved/loaded when layer is specified in the ModelConfig.

*in tensorflow>=2.9. In older versions, loading of the model will fail completely if the custom objects are not provided.
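To make the registration route concrete, here is a hedged sketch of registering a custom layer up front so load_model can restore the fully-functional class without a custom_objects dict (the package and class names are illustrative, not from this PR):

```python
import tensorflow as tf

# Registering a custom layer with TensorFlow's own serialization registry.
@tf.keras.utils.register_keras_serializable(package="my_package")
class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self, factor: float = 2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"factor": self.factor})
        return cfg
```

Alternatively, the same class can be supplied at load time via `tf.keras.models.load_model(path, custom_objects={"ScaleLayer": ScaleLayer})`.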

Main changes

  • Changed save_format from 'h5' to 'tf' in save_model and load_model, although stuck with h5 for the legacy save/load functions.
  • Removed support for passing a custom_objects dictionary via config, since support for this was very flaky. Custom objects in the dictionary could only realistically be specified as registered object strings (e.g. '@mymodel'). This was also confusing, as TensorFlow already has its own @tf.keras.utils.register_keras_serializable() decorator.
  • load_detector now allows arbitrary kwargs, which are passed to tf.keras.models.load_model (or torch.load). This is to be used to provide the custom_objects at load time (see example below).
  • Added subclassed models to CI.

Example

Example notebook demonstrating serialisation of a detector with a subclassed tensorflow model. Observe how the custom objects must be passed to load_detector in order to avoid the error KeyError: 'layers'.

Backwards compatibility

tf.keras.models.load_model automatically detects whether a given model path represents an h5 model or a SavedModel. This means we should be backwards compatible, in that we can simply move to saving SavedModels while still supporting loading of legacy h5 models.
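The detection itself can be illustrated with a small sketch (this is not TensorFlow's actual implementation, just a hypothetical helper mirroring the dispatch): a SavedModel is a directory containing a saved_model.pb file, while a legacy model is a single .h5/.hdf5 file.

```python
import os

# Illustrative format detection, mirroring what load_model does internally.
def detect_model_format(path: str) -> str:
    if os.path.isdir(path) and os.path.exists(os.path.join(path, "saved_model.pb")):
        return "SavedModel"
    if path.endswith((".h5", ".hdf5")):
        return "h5"
    raise ValueError(f"Unrecognised model format at {path!r}")
```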

TODO's

  • [x] Better document limitations of SavedModel format in docs.
  • ~~[ ] Add a more involved example of passing custom objects to load_detector.~~ More challenging than first envisaged; I wanted to demonstrate on the amazon example, where a subclassed ClassifierTF model is used, but saving there is not supported due to the use of tokenize_transformer. This is not related to subclassed models, so I'd like to leave it for a follow-up PR.
  • [ ] CHANGELOG.md - see https://github.com/SeldonIO/alibi-detect/pull/628#pullrequestreview-1272835501

Old notes etc

This notebook contains some experiments run to explore limitations with respect to the SavedModel format.

ascillitoe avatar Sep 21 '22 17:09 ascillitoe

Codecov Report

Merging #628 (1dc7b61) into master (c0c5e64) will decrease coverage by 0.03%. The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #628      +/-   ##
==========================================
- Coverage   80.35%   80.33%   -0.03%     
==========================================
  Files         137      137              
  Lines        9300     9304       +4     
==========================================
+ Hits         7473     7474       +1     
- Misses       1827     1830       +3     
| Flag | Coverage Δ |
| --- | --- |
| macos-latest-3.9 | 76.81% <100.00%> (+0.01%) :arrow_up: |
| ubuntu-latest-3.10 | 80.22% <100.00%> (+<0.01%) :arrow_up: |
| ubuntu-latest-3.7 | 80.12% <100.00%> (+0.01%) :arrow_up: |
| ubuntu-latest-3.8 | 80.17% <100.00%> (+0.01%) :arrow_up: |
| ubuntu-latest-3.9 | ? |
| windows-latest-3.9 | 76.81% <100.00%> (-0.02%) :arrow_down: |

Flags with carried forward coverage won't be shown. Click here to find out more.

| Impacted Files | Coverage Δ |
| --- | --- |
| alibi_detect/saving/schemas.py | 98.77% <ø> (-0.01%) :arrow_down: |
| alibi_detect/base.py | 85.45% <100.00%> (ø) |
| alibi_detect/saving/_pytorch/loading.py | 92.15% <100.00%> (+0.32%) :arrow_up: |
| alibi_detect/saving/_tensorflow/loading.py | 85.44% <100.00%> (+0.23%) :arrow_up: |
| alibi_detect/saving/_tensorflow/saving.py | 81.81% <100.00%> (-0.06%) :arrow_down: |
| alibi_detect/saving/loading.py | 93.83% <100.00%> (-0.03%) :arrow_down: |
| alibi_detect/datasets.py | 68.69% <0.00%> (-1.31%) :arrow_down: |

codecov-commenter avatar Sep 22 '22 11:09 codecov-commenter


The commits from f61af31 onwards contain three primary changes (based on discussion with @jklaise):

  1. Subclassed tf models must have been built/called before they can be saved. It was decided that attempting to perform a dummy call prior to saving was risky. Instead, the ValueError that occurs in alibi_detect.saving._tensorflow.save_model is caught and re-raised with a more informative error message.
  2. For custom layers or subclassed models, various errors can occur during inference or cloning if custom objects are not properly provided at load time. The errors are often unclear, and surface from a variety of places (e.g. in detector predict methods). alibi_detect.saving._tensorflow.load_model now runs some basic checks on the loaded model and raises a warning if any problems are detected. This allows problems to be discovered when the detector is first loaded, instead of having to wait until prediction time.
  3. A more prominent warning about providing custom objects at load time is added to the docs.
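Change (1) can be sketched roughly as follows; this is a hedged, simplified illustration of the catch-and-re-raise pattern, not the actual alibi_detect.saving._tensorflow.save_model code, and the function signature and message wording are illustrative:

```python
def save_model(model, filepath: str) -> None:
    """Save a TensorFlow model, re-raising unbuilt-model errors informatively."""
    try:
        model.save(filepath, save_format="tf")
    except ValueError as err:
        # Sub-classed models raise a ValueError here if save() is called
        # before the model has ever been built/called on data.
        raise ValueError(
            "Saving the model failed. Note that sub-classed TensorFlow models "
            "must be built (i.e. called on a batch of data) before they can "
            "be saved in the SavedModel format."
        ) from err
```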

ascillitoe avatar Jan 27 '23 13:01 ascillitoe

Postponing this to v0.12.0 so that we can combine it with deprecation of legacy saving and #723 .

ascillitoe avatar Jan 31 '23 10:01 ascillitoe