
tandem-embeddings-with-freezable-weights

Open · david-thrower opened this issue on Dec 16, 2023 · 0 comments

Kind of issue: The bottleneck on the tandem embeddings may be that the embedding converges to an optimum well before the dense layers do. Consequently, the embedding's gradients will zero out, and by the chain rule this will cascade and zero out all of the other gradients.
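
One way to test this hypothesis before changing the architecture is to log per-layer gradient norms during training and watch whether the embedding's norm collapses while the dense layers' norms stay large. A minimal sketch follows; model, x_batch, y_batch, and loss_fn are placeholders, not names from this codebase:

import tensorflow as tf

def log_gradient_norms(model, x_batch, y_batch, loss_fn):
    # Hypothetical diagnostic helper: print each trainable variable's
    # gradient norm for one batch.
    with tf.GradientTape() as tape:
        loss = loss_fn(y_batch, model(x_batch, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    for var, grad in zip(model.trainable_variables, grads):
        if grad is None:
            continue
        # Embedding gradients arrive as tf.IndexedSlices; densify first.
        grad = tf.convert_to_tensor(grad)
        print(f"{var.name}: grad norm = {tf.norm(grad).numpy():.3e}")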

A solution to try may look like this:

import tensorflow as tf
import numpy as np

class TemporalEmbedding(tf.keras.layers.Layer):
    """Embedding that trains for a set number of calls, then hands off to a
    frozen copy of its weights so the rest of the network keeps learning."""
    def __init__(self, vocab_size, embedding_dim, **kwargs):
        super().__init__(trainable=True)
        self.compute_gradient_for_n_epochs = 7
        self.train_counter = 0
        # Trainable embedding used during the warm-up phase.
        self.embedding_1 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        # Frozen twin that takes over once warm-up ends.
        self.embedding_2 = tf.keras.layers.Embedding(vocab_size, embedding_dim, **kwargs)
        self.embedding_2.trainable = False
    def build(self, input_shape):
        # Build both sub-layers up front so that set_weights() below has
        # weights to overwrite; set_weights() fails on an unbuilt layer.
        self.embedding_1.build(input_shape)
        self.embedding_2.build(input_shape)
        super().build(input_shape)
    def set_compute_gradient_for_n_epochs(self, n: int):
        self.compute_gradient_for_n_epochs = n
        print(f"Training this layer for only {self.compute_gradient_for_n_epochs} epochs")
    def call(self, inputs):
        # Caveat: call() runs once per batch (plus once when the functional
        # graph is built), so train_counter counts calls, not epochs, and the
        # model must be compiled with run_eagerly=True for this Python
        # counter to advance under model.fit().
        print(f"starting state: {self.train_counter}")
        if self.train_counter < self.compute_gradient_for_n_epochs:
            print(f"Training weights on call {self.train_counter}")
            self.train_counter += 1
            return self.embedding_1(inputs)
        elif self.train_counter == self.compute_gradient_for_n_epochs:
            print(f"Copying trained weights into the frozen embedding at call {self.train_counter}")
            self.train_counter += 1
            self.embedding_2.set_weights(self.embedding_1.get_weights())
            print("Returning output from the frozen embedding")
            return self.embedding_2(inputs)
        else:
            print(f"Returning output from the frozen embedding (call {self.train_counter})")
            self.train_counter += 1
            return self.embedding_2(inputs)


input_layer = tf.keras.layers.Input(shape=(100,))
# input_length must match the 100-token sequences defined by the Input layer.
temporal_embedding_layer = TemporalEmbedding(vocab_size=10000, embedding_dim=64, input_length=100)
temporal_embedding_layer.set_compute_gradient_for_n_epochs(n=3)
temporal_embedding_layer_called = temporal_embedding_layer(input_layer)
flat = tf.keras.layers.Flatten()(temporal_embedding_layer_called)
output_layer = tf.keras.layers.Dense(10, activation='softmax')(flat)
model2 = tf.keras.Model(inputs=input_layer, outputs=output_layer)
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
# run_eagerly=True so the Python-side counter in call() advances per batch.
model2.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'], run_eagerly=True)


# Random demonstration data: 200 sequences of 100 token ids, with random labels.
x_train = np.random.randint(10000, size=(200, 100))
y_train = np.random.randint(2, size=(200, 10))

model2.fit(x_train, y_train, epochs=20, batch_size=32)
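
For comparison, a similar freeze-after-n-epochs effect can be had without a custom layer by splitting fit() into two stages, in the spirit of the Keras transfer-learning workflow: train with the embedding trainable, then freeze it, re-compile, and continue. This is a minimal sketch, not a drop-in for the above; model3, the layer names, and the epoch counts are illustrative, and it reuses x_train/y_train from above:

embedding_layer = tf.keras.layers.Embedding(10000, 64)
input_layer_b = tf.keras.layers.Input(shape=(100,))
flat_b = tf.keras.layers.Flatten()(embedding_layer(input_layer_b))
output_layer_b = tf.keras.layers.Dense(10, activation='softmax')(flat_b)
model3 = tf.keras.Model(inputs=input_layer_b, outputs=output_layer_b)

# Stage 1: the embedding trains alongside the dense layers.
model3.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model3.fit(x_train, y_train, epochs=3, batch_size=32)

# Stage 2: freeze the embedding, re-compile so the change takes effect,
# and let only the dense layers keep learning.
embedding_layer.trainable = False
model3.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model3.fit(x_train, y_train, epochs=17, batch_size=32)

The trade-off is that the epoch budget is fixed up front rather than tracked inside the layer, which avoids the per-call counter and the run_eagerly requirement entirely.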

Suggested Labels (If you don't know, that's ok): kind/enhancement
