
Standard layers (like `tf.layers.dense`) fail to report their trainable weights when used from within a custom layer

Open Vectorrent opened this issue 1 year ago • 4 comments

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): true
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux
  • TensorFlow.js installed from (npm or script link): NPM
  • TensorFlow.js version (use command below): 4.17.0

Describe the current behavior

When building custom layers, it is often useful to use "standard" layer types like tf.layers.dense and tf.layers.lstm from inside that layer. However, layers added in this way have two major problems:

  1. Their trainable parameters are not reported by model.summary().
  2. Their weights are not exported with model.save().

This is problematic: the summary under-reports the model's size, and the trained weights of those layers are lost when the model is saved. The alternative is to use the this.addWeight() API; however, weights added in this way also have problems:

  1. It is wasteful and time-consuming to re-implement layer types that already exist in the standard API.
  2. A layer built from this.addWeight() weights cannot simply take a string activation, like mish or swish, the way the standard layers can.

If there is already a supported way to integrate the weights of a standard layer like tf.layers.dense into a custom layer, it is not clear from any of the documentation I've seen.

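Describe the expected behavior

I would expect every weight used by the computational graph to be counted in the "Trainable params" line of model.summary(). They are not: in the summary below, only the SharedEmbedding layer reports parameters, while every SelfAttention and MultiLayerPerceptron layer shows 0.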

___________________________________________________________________________________________________________________
Layer (type)               Input Shape                 Output shape             Param #     Receives inputs        
===================================================================================================================
inp-t0B (InputLayer)       [[null,null]]               [null,null]              0                                  
___________________________________________________________________________________________________________________
emb-gza (SharedEmbedding)  [[null,null]],[[null,null,2 multiple                 5091328     inp-t0B[0][0]          
                                                                                            mlp-adG[0][0]          
___________________________________________________________________________________________________________________
enc-RC2 (SinusoidalPositio [[null,null,256]]           [null,null,256]          0           emb-gza[0][0]          
___________________________________________________________________________________________________________________
attn-FBz (SelfAttention)   [[null,null,256]]           [null,null,256]          0           enc-RC2[0][0]          
___________________________________________________________________________________________________________________
mlp-3kL (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-FBz[0][0]         
___________________________________________________________________________________________________________________
attn-VZK (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-3kL[0][0]          
___________________________________________________________________________________________________________________
mlp-Jfy (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-VZK[0][0]         
___________________________________________________________________________________________________________________
attn-j0b (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-Jfy[0][0]          
___________________________________________________________________________________________________________________
mlp-oyK (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-j0b[0][0]         
___________________________________________________________________________________________________________________
attn-L1y (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-oyK[0][0]          
___________________________________________________________________________________________________________________
mlp-9r1 (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-L1y[0][0]         
___________________________________________________________________________________________________________________
attn-Yha (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-9r1[0][0]          
___________________________________________________________________________________________________________________
mlp-GV8 (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-Yha[0][0]         
___________________________________________________________________________________________________________________
attn-R5D (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-GV8[0][0]          
___________________________________________________________________________________________________________________
mlp-adG (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-R5D[0][0]         
===================================================================================================================
Total params: 5091328
Trainable params: 5091328
Non-trainable params: 0

Standalone code to reproduce the issue

Add the following custom layer to any model, then call model.compile() and model.summary(). You will see that the layer reports 0 trainable parameters:

import * as tf from '@tensorflow/tfjs'

class MultiLayerPerceptron extends tf.layers.Layer {
    constructor(config) {
        super({ ...config })
        this.units = config?.units || 256
        this.innerDim = config?.innerDim || 1024
    }

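    build(inputShape) {
        // The child layers are created here, but their weights are never
        // registered on this layer, which is the problem being reported.
        this.inProj = tf.layers.dense({
            units: this.innerDim,
            inputDim: this.units,
            activation: 'relu'
        })
        this.outProj = tf.layers.dense({
            units: this.units,
            inputDim: this.innerDim,
            activation: 'linear'
        })
    }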

    call(inputs, kwargs, training = false) {
        return tf.tidy(() => {
            inputs = Array.isArray(inputs) ? inputs[0] : inputs
            return this.outProj.apply(this.inProj.apply(inputs))
        })
    }

    computeOutputShape(inputShape) {
        return inputShape
    }

    getConfig() {
        return {
            ...super.getConfig(),
            units: this.units,
            innerDim: this.innerDim
        }
    }

    static get className() {
        return 'MultiLayerPerceptron'
    }
}

tf.serialization.registerClass(MultiLayerPerceptron)

Other info / logs

If there is a supported way to add the trainable parameters from tf.layers.dense() to my custom layer, please let me know!

Vectorrent, Apr 19 '24 21:04

So it does look like tf.layers.activation should work here as a replacement for the string activations. I hadn't noticed this layer type before: https://js.tensorflow.org/api/latest/#layers.activation

The other problems (parameters missing from model.summary() and weights missing from model.save()) still stand, as far as I can see.

Vectorrent, Apr 21 '24 18:04

You haven't registered the weights of the child layers on the parent layer in the build method:

this.trainableWeights = [...this.childLayer.trainableWeights]

The same goes for nonTrainableWeights.

We really need some documentation written up for the layers API, especially the RNNCell base class.

lukemovement, May 21 '24 05:05

Thanks for the suggestion, @lukemovement. It does work; however, you also need to build each of the child layers:

this.inProj.build(inputShape)
this.outProj.build([...inputShape.slice(0, -1), this.innerDim]) // outProj's input is inProj's output

I had actually discovered this trick previously, and it also has problems, which show up when restoring a model from disk:

/home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:273
        var _this = _super.call(this, message) || this;
                           ^

ValueError: Duplicate weight name: glu-TLd/kernel
    at new ValueError (/home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:273:28)
    at Container.loadWeights (/home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:21823:35)
    at /home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:25792:27
    at step (/home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:159:27)
    at Object.next (/home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:108:53)
    at fulfilled (/home/crow/Repos/ode/node_modules/@tensorflow/tfjs-layers/dist/tf-layers.node.js:89:28)

Node.js v18.18.2

When child layers' weights are added to trainableWeights like this, every child weight inherits the name of its parent layer (with a suffix):

  LayerVariable {
    dtype: 'float32',
    shape: [ 333 ],
    id: 30,
    originalName: 'glu-FpC/bias',
    name: 'glu-FpC/bias_2',
    trainable_: true,
    constraint: null,
    val: Variable {
      kept: false,
      isDisposedInternal: false,
      shape: [Array],
      dtype: 'float32',
      size: 333,
      strides: [],
      dataId: {},
      id: 49,
      rankType: '1',
      trainable: true,
      name: 'glu-FpC/bias_2'
    }
  }

As a result, restoring those child layers from a checkpoint fails with the "Duplicate weight name" error above. I thought it might be possible to fix this by overriding the getWeights() and setWeights() methods, but I couldn't find an immediate solution:

    getWeights() {
        return this.trainableWeights.map((weights) => weights.read())
    }

    setWeights(weights) {
        this.inProj.kernel.write(weights[0])
        this.inProj.bias.write(weights[1])
        this.outProj.kernel.write(weights[2])
        this.outProj.bias.write(weights[3])
    }
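One way to avoid hard-coding indices in setWeights() is to slice the flat array by each child's weight count. The helpers below are hypothetical (not part of TF.js) and assume every child is already built, so child.weights has its final length and matches the order of child.getWeights(). Judging by the stack trace, the duplicate-name error is raised inside Container.loadWeights, so this sketch mainly helps with a per-layer loader that calls layer.setWeights() directly. It probably does not help with tf.loadLayersModel().

```javascript
// Hypothetical helpers (not part of TF.js) for delegating getWeights() and
// setWeights() from a custom layer to its already-built child layers.

// Concatenate the children's weights in child order.
function getChildWeights(children) {
  return children.flatMap((child) => child.getWeights())
}

// Hand each child the slice of `weights` that matches its own weight count.
function setChildWeights(children, weights) {
  const expected = children.reduce((total, child) => total + child.weights.length, 0)
  if (weights.length !== expected) {
    throw new Error(`Expected ${expected} weight tensors, got ${weights.length}`)
  }
  let offset = 0
  for (const child of children) {
    const count = child.weights.length
    child.setWeights(weights.slice(offset, offset + count))
    offset += count
  }
}

// Inside the custom layer (sketch):
//   getWeights() { return getChildWeights([this.inProj, this.outProj]) }
//   setWeights(weights) { setChildWeights([this.inProj, this.outProj], weights) }
```

Adding a third child then only requires adding it to the children list, and a wrong number of tensors fails loudly instead of silently writing to the wrong variable.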

Anyway, I hope that additional context helps.

If you can point me to where I could contribute to the docs, I might be able to write a tutorial or something. I've probably built 100 custom layers at this point.

Vectorrent, May 27 '24 14:05

I'm unsure where this would best be placed. A layer's weights are always in the same order, so I use the following to save and load each layer's weights as a JSON file. It works as long as a layer's serialized weights don't exceed JavaScript's maximum string length.

import * as tf from "@tensorflow/tfjs";
import { mkdir, readFile, writeFile } from "fs/promises";
import { resolve } from "path";

export const SaveModel = async ({
  model,
  dir,
}: {
  model: tf.LayersModel;
  dir: string;
}) => {
  await mkdir(dir, { recursive: true });

  for (const layer of model.layers) {
    const name = layer.name;

    if (0 === layer.weights.length) {
      continue;
    }

    const weights = JSON.stringify(
      layer.getWeights().map((weight) => weight.arraySync()),
    );

    await writeFile(resolve(dir, `${name}.json`), weights);
  }

  console.log(`Saved to ${dir}`);
};

export const LoadModel = async ({
  model,
  dir,
}: {
  model: tf.LayersModel;
  dir: string;
}) => {
  for (const layer of model.layers) {
    const name = layer.name;

    if (0 === layer.weights.length) {
      continue;
    }

    try {
      const weights = JSON.parse(
        await readFile(resolve(dir, `${name}.json`), "utf-8"),
      );

      const tensors = weights.map((weight: number[]) => tf.tensor(weight));

      layer.setWeights(tensors);
    } catch (e) {
      console.log(layer.name, (e as Error).message);
    }
  }

  console.warn(`Loaded from ${dir}`);
};

lukemovement, May 28 '24 00:05