
Sharing weights across layers in keras 3 [feature request]

Open · nhuet opened this issue 1 year ago · 3 comments

It seems that sharing weights afterwards (i.e. after build) is no longer possible in Keras 3. We should instead share layers, as explained here.

But I have a use case where I need to share a weight

  • after init/build
  • without sharing a layer

In my use case, I transform a model by splitting the activations out of each layer, meaning that a Dense(3, activation="relu") is transformed into a Dense(3) layer followed by an Activation layer (see the sketch after the list below). But I need

  • to leave the original model unchanged (so I cannot just remove the activation from the original layer)
  • to share the original weights, so that further training of the original model also impacts the split layers and thus the converted model
  • preferably, the resulting layers to still be plain Keras layers (like Dense, not a new custom SplittedDense class)
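
To make the intended transformation concrete, here is a minimal sketch of the target structure (purely illustrative; the weight sharing itself is exactly what this issue is about):

from keras.layers import Activation, Dense, Input
from keras.models import Model

inputs = Input((1,))

# Original model: the activation is fused into the Dense layer.
original_dense = Dense(3, activation="relu")
original_model = Model(inputs, original_dense(inputs))

# Converted model: same architecture, but the activation split into its own layer.
# The goal is for split_dense to reuse original_dense.kernel and original_dense.bias.
split_dense = Dense(3)
split_model = Model(inputs, Activation("relu")(split_dense(inputs)))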

For now I have a solution, but it uses a private attribute, since by design this is currently not possible in Keras 3.

Here is an example that works for sharing the kernel (I will actually use something more generic to share any weight, but this is simpler to look at):

from keras.layers import Input, Dense


def share_kernel_and_build(layer1, layer2):
    # Check that layer1 is built and that layer2 is not built
    if not layer1.built:
        raise ValueError("The first layer must already be built for sharing its kernel.")
    if layer2.built:
        raise ValueError("The second layer must not be built to get the kernel of another layer")
    # Check that the input really exists (i.e. that the layer has already been called on a symbolic KerasTensor)
    input = layer1.input  # will raise a ValueError if not existing

    # store the kernel as a layer2 variable before build (i.e. before layer2's weights are locked)
    layer2.kernel = layer1.kernel
    # build the layer
    layer2(input)
    # overwrite the newly generated kernel
    kernel_to_drop = layer2.kernel
    layer2.kernel = layer1.kernel
    # untrack the kernel that is no longer used  (oops: using a private attribute!)
    layer2._tracker.untrack(kernel_to_drop)


layer1 = Dense(3)
input = Input((1,))
output = layer1(input)
layer2 = Dense(3)

share_kernel_and_build(layer1, layer2)

assert layer2.kernel is layer1.kernel
assert len(layer2.weights) == 2

Notes:

  • doing layer2.kernel = layer1.kernel only after build would raise an error because of the lock (see the sketch after these notes).
  • assigning it before build allows it to be reassigned after build: the variable is then already tracked, so the assignment does not go through add_to_store.
  • not untracking the unused kernel would result in an additional weight being tracked by the layer.
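
For illustration, here is a minimal sketch of the lock behavior mentioned in the first note (the exact error type and message depend on the Keras version; from Keras 3.0.3 onward the assignment fails even earlier because kernel became a read-only property, as discussed further down):

from keras.layers import Dense, Input

shared_source = Dense(3)
shared_source(Input((1,)))  # build a layer whose kernel we would like to reuse

locked_layer = Dense(3)
locked_layer(Input((1,)))  # building the layer locks its state tracker

try:
    # this variable is not tracked by locked_layer yet, so the assignment goes
    # through add_to_store, which is refused once the layer is built
    locked_layer.kernel = shared_source.kernel
except Exception as e:
    print(e)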

nhuet · Nov 23 '23 13:11

A simpler solution to your problem would be:

  1. Instantiate the new Dense layer, e.g. dense = Dense.from_config(...). (It doesn't have weights at that time)
  2. Set dense.kernel = old_layer.kernel, dense.bias = old_layer.bias, dense.built = True
  3. Just use the layer -- no new weights will be created since the layer is already built (see the sketch below).
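
A minimal sketch of this recipe, assuming a Keras 3 version prior to 3.0.3 where Dense.kernel is still a plain, settable attribute (see the follow-up below):

from keras.layers import Dense, Input

old_layer = Dense(3)
old_layer(Input((1,)))  # build the original layer so it owns a kernel and a bias

# 1. Instantiate a new Dense layer from the old layer's config: it has no weights yet.
dense = Dense.from_config(old_layer.get_config())

# 2. Point the new layer at the existing variables and mark it as built.
dense.kernel = old_layer.kernel
dense.bias = old_layer.bias
dense.built = True

# 3. Using the layer creates no new weights, since it is already built.
outputs = dense(Input((1,)))
assert dense.kernel is old_layer.kernel
assert len(dense.weights) == 2  # the shared kernel and the shared bias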

fchollet · Nov 30 '23 00:11

Nice! But are we sure that the build() method only creates the weights? Perhaps I will miss something else by skipping build()? I would like a solution that works with any layer. By setting self.built = True, I skip build() and thus do not overwrite the weights, but is there anything else important not to bypass so that call() works? At least, it seems build() also sets the input_spec attribute, but perhaps that will not be too much of a loss (and I can also copy it from the previous layer).

nhuet · Dec 04 '23 11:12

A simpler solution to your problem would be:

  1. Instantiate the new Dense layer, e.g. dense = Dense.from_config(...). (It doesn't have weights at that time)
  2. Set dense.kernel = old_layer.kernel, dense.bias = old_layer.bias, dense.built = True
  3. Just use the layer -- no new weights will be created since the layer is already built

This does not work anymore as of Keras 3.0.3, since Dense.kernel is now a property without a setter...

nhuet · Jan 22 '24 13:01

We'll add a setter for the kernel.

fchollet · Apr 12 '24 05:04

Thx!

nhuet · Apr 12 '24 14:04

The setter turned out to be problematic. What I would recommend is just direct setting, but using ._kernel instead of .kernel (sketched below).

Ref: https://github.com/keras-team/keras/pull/19469
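
For illustration, the same recipe adapted to Keras >= 3.0.3, where Dense.kernel is a read-only property backed by the private ._kernel attribute (a sketch relying on that private attribute, and assuming bias is still directly settable):

from keras.layers import Dense, Input

old_layer = Dense(3)
old_layer(Input((1,)))  # build the original layer

dense = Dense.from_config(old_layer.get_config())  # new layer, no weights yet
dense._kernel = old_layer.kernel  # direct setting of the private attribute
dense.bias = old_layer.bias
dense.built = True  # skip build() so no new weights get created

dense(Input((1,)))
assert dense.kernel is old_layer.kernel  # the kernel property returns the shared variable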

fchollet · Apr 12 '24 20:04