
MaskedLMHead embeddings

Open AakashKumarNain opened this issue 8 months ago • 11 comments

I noticed a change that was introduced in the MaskedLMHead layer, and it broke my entire workflow. Earlier, the signature of MaskedLMHead looked like this:

out = keras_nlp.layers.MaskedLMHead(
    embedding_weights=...,
    activation=...,
)(encoded_tokens, masked_positions)

where we could pass the embedding matrix directly to the layer. Now it expects an instance of a ReversibleEmbedding layer. I have two questions about this change:

  1. Was there any specific reason for this change? Are there any related discussions/issues that led to it?
  2. If I had an embedding layer defined like the one below, what's the best way to pass the token embedding weights to MaskedLMHead?
import tensorflow as tf
from tensorflow.keras import layers  # assuming the TF backend, matching the tf.* ops below


class PositionalEmbedding(layers.Layer):
    def __init__(self, seq_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.seq_length = seq_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.token_embeddings = layers.Embedding(input_dim=input_dim, output_dim=output_dim)
        self.position_embeddings = layers.Embedding(input_dim=seq_length, output_dim=output_dim)
        self.supports_masking = True

    def call(self, inputs):
        # Sum token embeddings with learned position embeddings.
        seq_length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=seq_length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # Treat token id 0 as padding.
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update({
            "seq_length": self.seq_length,
            "input_dim" : self.input_dim,
            "output_dim": self.output_dim,
        })
        return config
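
For reference, under the old API this layer was wired to the head roughly as follows (a minimal sketch: the sizes are placeholders, the encoder stack is elided, and the `.embeddings` attribute is assumed to be the Keras Embedding layer's underlying weight matrix):

embedding = PositionalEmbedding(seq_length=512, input_dim=30000, output_dim=256)  # placeholder sizes
# ... build the encoder on top of `embedding` to produce `encoded_tokens` ...
out = keras_nlp.layers.MaskedLMHead(
    embedding_weights=embedding.token_embeddings.embeddings,  # raw token embedding matrix
    activation=...,
)(encoded_tokens, masked_positions)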

AakashKumarNain avatar Oct 13 '23 10:10 AakashKumarNain

For context: #1201 and, subsequently, https://github.com/keras-team/keras/issues/18419

shivance avatar Oct 13 '23 10:10 shivance

Thanks for the context @shivance

Do you have pointers for the second question as well?

cc: @mattdangerw

AakashKumarNain avatar Oct 13 '23 15:10 AakashKumarNain

@AakashKumarNain Maybe you could directly use TokenAndPositionEmbedding.

shivance avatar Oct 14 '23 04:10 shivance

That would require me to train the models again

AakashKumarNain avatar Oct 14 '23 07:10 AakashKumarNain

I tried replacing the custom PositionalEmbedding layer with the TokenAndPositionEmbedding layer, and the training stats are way off now. A few interesting points:

  1. Even with the default initialization, the training loss is very high at the start of training, suggesting that the uniform initialization may not be the best choice.
  2. The training loss decreases normally for a few epochs, but then it spikes suddenly and never gets back on track.

QQ: After we made that change, did we train some models to validate that it is working as expected, especially for pretraining tasks?

[Screenshot of training stats, 2023-10-16]

AakashKumarNain avatar Oct 16 '23 10:10 AakashKumarNain

@AakashKumarNain if the initialization isn't optimal, could you try overriding the embeddings_initializer arg and report back any strong findings?

Unfortunately e2e tests for training workflows are extremely expensive and problem specific.

jbischof avatar Oct 16 '23 18:10 jbischof

Unfortunately e2e tests for training workflows are extremely expensive and problem specific.

@jbischof I understand, but common tasks like pretraining a BERT-like model on a small dataset (e.g. WikiText) don't take too much time. The problem with not doing this exercise after a breaking change is that unless you train the model, you won't know whether the change broke something else. Loading weights and testing some outputs doesn't give a full picture of a breaking change.

AakashKumarNain avatar Oct 17 '23 17:10 AakashKumarNain

@AakashKumarNain sorry about the breakage here!

For the original question above... would it work to replace the token_embedding with a keras_nlp.layers.ReversibleEmbedding, and pass that token embedding to the masked language model head?

class PositionalEmbedding(layers.Layer):
    def __init__(self, seq_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.seq_length = seq_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.token_embeddings = keras_nlp.layers.ReversibleEmbedding(input_dim=input_dim, output_dim=output_dim)
        self.position_embeddings = layers.Embedding(input_dim=seq_length, output_dim=output_dim)
        self.supports_masking = True
...
embedding = PositionalEmbedding(...)
masked_lm_head = keras_nlp.layers.MaskedLMHead(
    embedding=embedding.token_embeddings,
    activation=...,
)(encoded_tokens, masked_positions)

If your custom layer is registered as serializable, you should be able to just reload the old model, I think, but it's hard to know without seeing exactly how you are serializing it. You could also consider calling old_model.save_weights(...) with the old package version and then loading into the new model via load_weights, roughly as in the sketch below.
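
A minimal sketch of that weights round trip (the file name and model builder are placeholders, not from this thread):

# Step 1: with the old keras-nlp version installed, export only the weights.
old_model.save_weights("mlm_pretrained.weights.h5")  # placeholder path

# Step 2: after upgrading, rebuild the same architecture using the new
# MaskedLMHead signature (which takes a ReversibleEmbedding), then load.
new_model = build_mlm_model()  # hypothetical builder for the updated model
new_model.load_weights("mlm_pretrained.weights.h5")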

mattdangerw avatar Oct 17 '23 23:10 mattdangerw

For some general context, we did decide to move away from the old way of passing embedding weights to the MaskedLMHead because it would not save correctly with the upcoming Keras 3. Keras 3 will support deduplicating saved objects at the layer level, but not at the weights level, so this could lead to some unexpected behavior from the layer after a round trip of saving and loading.
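
Roughly, the distinction looks like this (a minimal sketch with illustrative names and sizes, showing the layer-level sharing that Keras 3 saving can deduplicate):

import keras
import keras_nlp

# One shared *layer object* holds the token embedding.
token_embedding = keras_nlp.layers.ReversibleEmbedding(
    input_dim=30000, output_dim=256,  # illustrative vocab size / hidden dim
)

token_ids = keras.Input(shape=(None,), dtype="int32")
x = token_embedding(token_ids)             # embed token ids
# ... encoder layers would go here ...
logits = token_embedding(x, reverse=True)  # reuse the same layer as the output projection
model = keras.Model(token_ids, logits)

# Because the projection reuses the layer object rather than its raw weight
# tensors, saving and loading keeps the weights tied; sharing only the weight
# tensors would not survive the round trip.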

In general we try to avoid breakages like this, but Keras 3 is a bit of an exception in that it is a large change and some API incompatibilities are inevitable.

We could however do a better job of documenting any known breaking changes when they occur in our release notes. This is a good reminder. And in general this sort of churn should die down after we release the next major version of Keras.

mattdangerw avatar Oct 17 '23 23:10 mattdangerw

Thanks for the detailed response @mattdangerw

would it work to replace the token_embedding with a keras_nlp.layers.ReversibleEmbedding

I can try that and will report back. Btw, directly using the TokenAndPositionEmbedding layer from keras_nlp didn't give me the desired results, as pointed out above.

In general we try to avoid breakages like this, but Keras 3 is a bit of an exception in that it is a large change and some API incompatibilities are inevitable.

I am aware of this. In fact, I was writing this code to add it to the code examples with the JAX backend, so this is a bit ironic. When I noticed the change, I was surprised because I was under the impression that, going forward, we would maintain backward compatibility to the fullest. I had missed the issue linked in the Keras repo about this. Let me try retraining the whole thing.

PS: I still recommend testing e2e for breaking changes

AakashKumarNain avatar Oct 18 '23 18:10 AakashKumarNain

Sounds good, keep us posted!

And yeah, the TokenAndPositionEmbedding initialization is an interesting question. I don't think there is any way to guarantee stable training performance with arbitrary architectures and optimizers; that is a non-goal. The fact that we use the same initializer as keras.layers.Embedding is good for the overall consistency of Keras.

If you are training a transformer from scratch, I would strongly recommend supplying initializers, tested for your particular arch/optimization strategy, to most of your layers. You could try keras.initializers.TruncatedNormal(stddev=0.02), a magic constant from BERT that has been copied by many LLMs since; see the sketch below.
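
A minimal sketch of wiring that in (the embeddings_initializer argument is the one mentioned earlier in the thread; sizes are placeholders):

import keras
import keras_nlp

init = keras.initializers.TruncatedNormal(stddev=0.02)  # BERT-style initialization

embedding = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=30000,   # placeholder vocab size
    sequence_length=512,     # placeholder max sequence length
    embedding_dim=256,       # placeholder hidden size
    embeddings_initializer=init,
)

# The same initializer can also be passed to other layers in the model,
# e.g. keras.layers.Dense(256, kernel_initializer=init).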

I would consider it a bug that we don't support setting the position embedding initializer and token embedding initializer separately. We should probably extend our init arguments to allow this.

PS: I still recommend testing e2e for breaking changes

An integration test asserting stability for a bit of training with a from-scratch transformer is a good idea; I will open up an issue. Though the biggest fish we have to fry is going to be continuous multi-framework benchmarking for preset models as we launch Keras 3.

mattdangerw avatar Oct 18 '23 20:10 mattdangerw