MultiHeadAttention instances within custom layers do not receive weights when loaded with SavedModel format

Open SirDavidLudwig opened this issue 2 years ago • 7 comments

Please go to TF Forum for help and support:

https://discuss.tensorflow.org/tag/keras

If you open a GitHub issue, here is our policy:

It must be a bug, a feature request, or a significant problem with the documentation (for small docs fixes please send a PR instead). The form below must be filled out.

Here's why we have that policy:

Keras developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04 LTS
  • TensorFlow installed from (source or binary): binary (conda install)
  • TensorFlow version (use command below): 2.8.1
  • Python version: 3.10.4
  • Bazel version (if compiling from source): N/A
  • GPU model and memory: NVIDIA GeForce RTX 3090 24GB
  • Exact command to reproduce: N/A

You can collect some of this information using our environment capture script:

https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

You can obtain the TensorFlow version with: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the problem.

When a MultiHeadAttention layer is nested inside a custom layer, it does not properly receive its weights when the model is loaded from the SavedModel format. As a result, the weights of these nested MHA layers in the loaded model differ from the saved weights. The problem does not occur when saving and loading via .h5.

Describe the current behavior.

When loading a Keras model saved with the SavedModel format, MultiHeadAttention (MHA) instances contained within custom layers do not properly receive their loaded weights. As a result, the loaded model does not match the saved model: the nested MHA layers have different weights from those that were saved.

Describe the expected behavior.

As with other nested layers, a loaded model should have weights identical to those of the saved model. This is the case for other layers nested inside a custom layer class, as demonstrated in the provided notebook. For example, if a Dense layer is nested within a custom layer, its weights are loaded properly and the weights of the saved and loaded models match.
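One way to check this expectation is to compare the lists returned by `get_weights()` on the saved and loaded models. The sketch below is only an illustration: `weights_match` is a hypothetical helper name, not part of Keras, and it assumes both models return their weights in the same order.

```python
import numpy as np


def weights_match(saved_weights, loaded_weights, atol=0.0):
    """Return True if both weight lists hold arrays of equal shape and value.

    Hypothetical helper: pass it saved_model.get_weights() and
    loaded_model.get_weights(). With atol=0.0 the comparison is exact.
    """
    if len(saved_weights) != len(loaded_weights):
        return False
    for saved, loaded in zip(saved_weights, loaded_weights):
        saved, loaded = np.asarray(saved), np.asarray(loaded)
        # Check shapes first so that broadcasting cannot hide a mismatch.
        if saved.shape != loaded.shape:
            return False
        if not np.allclose(saved, loaded, rtol=0.0, atol=atol):
            return False
    return True
```

With this helper, the behavior described here would show up as `weights_match(model.get_weights(), loaded.get_weights())` returning `False` after a SavedModel round trip, while the same check on an .h5 round trip would return `True`.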

Contributing.

  • Do you want to contribute a PR? (yes/no): no
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing):

Standalone code to reproduce the issue.

Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook.

https://gist.github.com/SirDavidLudwig/9c7fe36e4513ef040959ffd3178c1111

Source code / logs.

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem.

N/A

SirDavidLudwig avatar Jul 15 '22 19:07 SirDavidLudwig

@gadagashwini, I was able to reproduce the issue on tensorflow v2.8, v2.9 and nightly. Kindly find the gist of it here.

tilakrayal avatar Jul 18 '22 06:07 tilakrayal

Thanks for the detailed bug report. This seems to be quite an important issue, since MHA is a core API.

qlzh727 avatar Jul 20 '22 18:07 qlzh727

Assigning to Kathy for more input.

qlzh727 avatar Jul 20 '22 18:07 qlzh727

This fails because the MHA layer runs extra instructions when it is deserialized with from_config, and those instructions are not run when the layer is constructed directly from num_heads and key_dim: https://github.com/keras-team/keras/blob/v2.9.0/keras/layers/attention/multi_head_attention.py#L303

Without this line, the layer is marked as unbuilt (_built_from_signature = False), so it'll try to create new variables when called. This is why the layer appears to have newly initialized weights, instead of the checkpointed weights.
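The mechanism can be illustrated with a framework-free toy. Everything below is hypothetical and heavily simplified (`ToyAttention`, `restore`, and the list-of-floats "weights" are stand-ins, not Keras code); it only mirrors the idea that an unbuilt layer re-creates its variables on the first call, while a layer rebuilt through `from_config` keeps the restored values.

```python
import random


class ToyAttention:
    """Toy stand-in for a layer that creates its variables lazily."""

    def __init__(self, num_heads):
        self.num_heads = num_heads
        self.built = False
        self.weights = None

    def build(self):
        # Creates freshly initialized weights, replacing any existing ones.
        self.weights = [random.random() for _ in range(self.num_heads)]
        self.built = True

    def __call__(self):
        # An unbuilt layer builds itself on first use.
        if not self.built:
            self.build()
        return self.weights

    @classmethod
    def from_config(cls, config):
        layer = cls(**config)
        layer.build()  # the "extra instruction": the layer is built up front
        return layer


def restore(layer, checkpoint):
    """Toy loader step: attach checkpointed values to the layer."""
    layer.weights = list(checkpoint)
    return layer


checkpoint = [1.0, 2.0]

# Re-created through the constructor: still unbuilt, so the first call
# builds the layer and overwrites the restored values.
via_init = restore(ToyAttention(num_heads=2), checkpoint)

# Re-created through from_config: already built, so the restored values
# survive the first call.
via_config = restore(ToyAttention.from_config({"num_heads": 2}), checkpoint)

print(via_init() == checkpoint)    # False
print(via_config() == checkpoint)  # True
```

The real loader is more involved than this, but the toy shows why routing the nested layer through its own `from_config` matters.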

This is the correct way to create a custom layer with an MHA layer:

import tensorflow as tf
from tensorflow import keras


class MyCustomMhaLayer(keras.layers.Layer):
    def __init__(self, embed_dim=None, num_heads=None, mha=None, **kwargs):
        super().__init__(**kwargs)
        if mha is None:
            # Normal construction: create a fresh MHA layer from the hyperparameters.
            self.mha = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        else:
            # Deserialization path: reuse the MHA layer restored by from_config().
            self.mha = mha

    def call(self, x, training=None):
        return self.mha(x, x, x, training=training)

    def get_config(self):
        config = super().get_config()
        # Serialize the nested layer so that its own from_config() runs on load.
        config.update({
            "mha": tf.keras.layers.serialize(self.mha)
        })
        return config

    @classmethod
    def from_config(cls, config):
        config['mha'] = tf.keras.layers.deserialize(config['mha'])
        return super().from_config(config)

@scottzhu, this could be useful to put in the MHA docstring. Could you see if anyone on the Keras team has bandwidth to do this?

k-w-w avatar Jul 20 '22 22:07 k-w-w

+@rchao This kind of issue is somewhat common: anyone who tries to create a subclassed MHA layer will run into it. The new idempotent saving format will also be affected, as it is entirely config-based. Is there a way to cleanly solve this problem?

(e.g., letting the user know if from_config isn't defined correctly)

k-w-w avatar Jul 20 '22 22:07 k-w-w

@k-w-w If we have to serialize the layer ourselves in this manner, are we therefore required to supply the MHA layer to the constructor, or is there another way to keep everything self-contained?

SirDavidLudwig avatar Jul 21 '22 15:07 SirDavidLudwig

@SirDavidLudwig In the code snippet in my previous comment, you can pass either embed_dim and num_heads, or the mha layer, into the constructor. The mha argument is needed only for from_config().

k-w-w avatar Jul 21 '22 20:07 k-w-w