Standard layers (like `tf.layers.dense`) fail to report their trainable weights when used from within a custom layer

Vectorrent opened this issue 4 months ago · 4 comments

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): true
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux
  • TensorFlow.js installed from (npm or script link): NPM
  • TensorFlow.js version (use command below): 4.17.0

Describe the current behavior

When building custom layers, it is often useful to use "standard" layer types like tf.layers.dense and tf.layers.lstm from inside that layer. However, layers added in this way have two major problems:

  1. Their trainable parameters are not reported by model.summary().
  2. Their weights are not exported with model.save() (see the sketch just below this list).
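The export problem can be checked directly. This is a minimal sketch, not the exact failing setup: it assumes a model built around the MultiLayerPerceptron layer from the repro code below, and uses tf.io.withSaveHandler purely to count the weight specs that save() would write (run inside an async function):

// Assumes `model` contains the MultiLayerPerceptron layer defined below.
await model.save(tf.io.withSaveHandler(async (artifacts) => {
    // The nested dense kernels/biases never show up here.
    console.log(artifacts.weightSpecs.length)
    return { modelArtifactsInfo: { dateSaved: new Date(), modelTopologyType: 'JSON' } }
}))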

This is problematic for obvious reasons. The alternative is to use the this.addWeight() API; however, weights added in this way also have problems:

  1. It is wasteful and time-consuming to re-implement layer types that already exist in the standard API.
  2. Weights added via this.addWeight() cannot use string activations like mish and swish, which must instead be re-implemented by hand (sketched below).
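For instance, with raw this.addWeight() kernels, swish and mish have to be applied manually. A sketch using their standard definitions (the kernel/bias names here are hypothetical weights created via this.addWeight()):

// swish(x) = x * sigmoid(x); mish(x) = x * tanh(softplus(x))
const swish = (x) => tf.mul(x, tf.sigmoid(x))
const mish = (x) => tf.mul(x, tf.tanh(tf.softplus(x)))

// Inside call(), after the manual projection:
// const z = tf.add(tf.matMul(x, this.kernel.read()), this.bias.read())
// return swish(z)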

If there is already a supported way to integrate the weights from a standard layer like tf.layers.dense into a custom layer, that method is not clear from any of the documentation I've seen.
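The closest thing to a workaround I can sketch is an assumption on my part, not a documented API: build the sub-layers explicitly, then re-expose their variables by overriding the trainableWeights getter, which the weight-collection code paths appear to consult:

build(inputShape) {
    this.inProj = tf.layers.dense({ units: this.innerDim, activation: 'relu' })
    this.outProj = tf.layers.dense({ units: this.units, activation: 'linear' })
    // Build the children up front so their variables exist immediately.
    this.inProj.build(inputShape)
    this.outProj.build([...inputShape.slice(0, -1), this.innerDim])
}

// WARNING: this leans on Layer internals and is not documented behavior.
get trainableWeights() {
    return [...this.inProj.trainableWeights, ...this.outProj.trainableWeights]
}

Even if this makes the parameters visible, I have found no documentation confirming it is the supported route.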

Describe the expected behavior

I would expect weights used by the computational graph to be included in model.summary()'s "Trainable params" report, but they are not:

___________________________________________________________________________________________________________________
Layer (type)               Input Shape                 Output shape             Param #     Receives inputs        
===================================================================================================================
inp-t0B (InputLayer)       [[null,null]]               [null,null]              0                                  
___________________________________________________________________________________________________________________
emb-gza (SharedEmbedding)  [[null,null]],[[null,null,2 multiple                 5091328     inp-t0B[0][0]          
                                                                                            mlp-adG[0][0]          
___________________________________________________________________________________________________________________
enc-RC2 (SinusoidalPositio [[null,null,256]]           [null,null,256]          0           emb-gza[0][0]          
___________________________________________________________________________________________________________________
attn-FBz (SelfAttention)   [[null,null,256]]           [null,null,256]          0           enc-RC2[0][0]          
___________________________________________________________________________________________________________________
mlp-3kL (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-FBz[0][0]         
___________________________________________________________________________________________________________________
attn-VZK (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-3kL[0][0]          
___________________________________________________________________________________________________________________
mlp-Jfy (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-VZK[0][0]         
___________________________________________________________________________________________________________________
attn-j0b (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-Jfy[0][0]          
___________________________________________________________________________________________________________________
mlp-oyK (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-j0b[0][0]         
___________________________________________________________________________________________________________________
attn-L1y (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-oyK[0][0]          
___________________________________________________________________________________________________________________
mlp-9r1 (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-L1y[0][0]         
___________________________________________________________________________________________________________________
attn-Yha (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-9r1[0][0]          
___________________________________________________________________________________________________________________
mlp-GV8 (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-Yha[0][0]         
___________________________________________________________________________________________________________________
attn-R5D (SelfAttention)   [[null,null,256]]           [null,null,256]          0           mlp-GV8[0][0]          
___________________________________________________________________________________________________________________
mlp-adG (MultiLayerPercept [[null,null,256]]           [null,null,256]          0           attn-R5D[0][0]         
===================================================================================================================
Total params: 5091328
Trainable params: 5091328
Non-trainable params: 0

Standalone code to reproduce the issue

Add the following custom layer to any model, then call model.compile() and model.summary(). You will see that the layer reports 0 trainable parameters:

class MultiLayerPerceptron extends tf.layers.Layer {
    constructor(config) {
        super({ ...config })
        this.units = config?.units || 256
        this.innerDim = config?.innerDim || 1024
    }

    build(inputShape) {
        // NOTE: these sub-layers are never registered with the parent
        // layer, so their weights go untracked by summary() and save().
        this.inProj = tf.layers.dense({
            units: this.innerDim,
            inputDim: this.units,
            activation: 'relu'
        })
        this.outProj = tf.layers.dense({
            units: this.units,
            inputDim: this.innerDim,
            activation: 'linear'
        })
    }

    call(inputs, kwargs) {
        return tf.tidy(() => {
            inputs = Array.isArray(inputs) ? inputs[0] : inputs
            return this.outProj.apply(this.inProj.apply(inputs))
        })
    }

    computeOutputShape(inputShape) {
        return inputShape
    }

    getConfig() {
        return {
            ...super.getConfig(),
            units: this.units,
            innerDim: this.innerDim
        }
    }

    static get className() {
        return 'MultiLayerPerceptron'
    }
}

tf.serialization.registerClass(MultiLayerPerceptron)
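A minimal driver for the repro, assuming a trivial flat input shape (any model routing through the layer shows the same zero-parameter row):

const input = tf.input({ shape: [256] })
const output = new MultiLayerPerceptron({ units: 256 }).apply(input)
const model = tf.model({ inputs: input, outputs: output })

model.compile({ optimizer: 'adam', loss: 'meanSquaredError' })
model.summary() // the MultiLayerPerceptron row reports Param # 0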

Other info / logs

If there is a supported way to add the trainable parameters from tf.layers.dense() to my custom layer, please let me know!

Vectorrent · Apr 19 '24 21:04