tfjs icon indicating copy to clipboard operation
tfjs copied to clipboard

T5 Text-to-Text Transformer

Open lukemovement opened this issue 1 year ago • 0 comments

System information

  • TensorFlow.js version (you are using): 4.19.0
  • Are you willing to contribute it (Yes/No): yes

Describe the feature and the current behavior/state. T5 Text-to-Text Transformer

Will this change the current API? How? N/A

Who will benefit from this feature? Anyone looking to use the layer.

Any other info: To the best of my knowledge, I have built the transformer introduced in the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" by Colin Raffel et al. (2020). I'd like to contribute it; however, I may need some help bringing it in line with the project's standards.

Am I in the ballpark, or is this a swing and a miss?

import * as tf from "@tensorflow/tfjs";
import type { DenseLayerArgs } from "@tensorflow/tfjs-layers/dist/layers/core";

/** Dense options for a bias-free projection whose width (`units`) the caller picks. */
type T5ProjectionArgs = Omit<DenseLayerArgs, "activation" | "useBias">;

/**
 * Dense options for a bias-free projection back to the model width. `units`
 * is not accepted because the layer sets it to the input's last dimension.
 */
type T5ModelWidthArgs = Omit<DenseLayerArgs, "units" | "activation" | "useBias">;

/** Projections used by one attention sub-block. */
interface T5AttentionArgs {
  queriesArgs: T5ProjectionArgs;
  keysArgs: T5ProjectionArgs;
  valuesArgs: T5ProjectionArgs;
  outputArgs: T5ModelWidthArgs;
}

/** The two dense layers of a feed-forward sub-block (an activation sits between them). */
interface T5FeedForwardArgs {
  forwardFeed1Args: T5ProjectionArgs;
  forwardFeed2Args: T5ModelWidthArgs;
}

/**
 * Constructor options for the T5 transformer layer.
 *
 * `activation` and `useBias` are omitted everywhere because the layer forces
 * `useBias: false` on every dense projection and supplies its own activation.
 */
export interface T5TransformerLayerArgs {
  /** Optional layer name, forwarded to the base `Layer`. */
  name?: string;
  /** Number of attention heads; projection widths are split evenly across them. */
  numHeads: number;
  /** First input branch: self-attention followed by a feed-forward block. */
  input1: T5AttentionArgs & T5FeedForwardArgs;
  /** Second input branch: self-attention with masking enabled, no feed-forward. */
  input2: T5AttentionArgs;
  /** Attention over the two encoded branches, followed by a feed-forward block. */
  output: T5AttentionArgs & T5FeedForwardArgs;
}

/**
 * Encoder/decoder transformer block modelled on T5 ("Exploring the Limits of
 * Transfer Learning with a Unified Text-to-Text Transformer").
 *
 * Inputs: a single tensor (used for both branches) or `[input1, input2]`.
 * Each input is assumed to be `[..., seq, features]`, with the same
 * `features` size for both — NOTE(review): confirm against callers.
 *
 * Pipeline (see `call`):
 *  1. `encodeInput1`: un-masked self-attention + feed-forward on input 1.
 *  2. `encodeInput2`: causally masked self-attention on input 2.
 *  3. `decodeOutput`: cross-attention (queries from step 2, keys/values from
 *     step 1) + feed-forward, with the residual taken from step 2.
 *
 * The output therefore has the shape of the last input.
 *
 * NOTE(review): no positional information (e.g. T5's relative position bias)
 * is added here; callers must encode positions themselves.
 */
export class _T5Transformer extends tf.layers.Layer {
  private readonly options: T5TransformerLayerArgs;

  private input1KeysProjection!: tf.layers.Layer;
  private input1ValuesProjection!: tf.layers.Layer;
  private input1QueriesProjection!: tf.layers.Layer;
  private input1OutputProjection!: tf.layers.Layer;

  private input2KeysProjection!: tf.layers.Layer;
  private input2ValuesProjection!: tf.layers.Layer;
  private input2QueriesProjection!: tf.layers.Layer;
  private input2OutputProjection!: tf.layers.Layer;

  private outputKeysProjection!: tf.layers.Layer;
  private outputValuesProjection!: tf.layers.Layer;
  private outputQueriesProjection!: tf.layers.Layer;
  private outputOutputProjection!: tf.layers.Layer;

  private input1ForwardFeed1!: tf.layers.Layer;
  private input1Activation!: tf.layers.Layer;
  private input1ForwardFeed2!: tf.layers.Layer;

  private outputForwardFeed1!: tf.layers.Layer;
  private outputActivation!: tf.layers.Layer;
  private outputForwardFeed2!: tf.layers.Layer;

  private normalizeLayer!: tf.layers.Layer;

  /**
   * @param options Layer configuration; see `T5TransformerLayerArgs`.
   * @throws Error if `numHeads` is not a positive integer.
   */
  constructor(options: T5TransformerLayerArgs) {
    super(options);

    if (!Number.isInteger(options.numHeads) || options.numHeads < 1) {
      throw new Error(
        `T5Transformer: numHeads must be a positive integer, got ${options.numHeads}.`,
      );
    }

    this.options = options;
  }

  /**
   * Creates and builds every sub-layer, then exposes their weights as this
   * layer's weights.
   *
   * With multiple input shapes only the first is used. Every sub-layer only
   * depends on the last (feature) dimension, which all inputs are expected
   * to share.
   */
  build(inputShape: tf.Shape | tf.Shape[]): void {
    const shape = (
      Array.isArray(inputShape[0]) ? inputShape[0] : inputShape
    ) as tf.Shape;
    const dims = shape[shape.length - 1] as number;
    const { input1, input2, output } = this.options;

    // Creation order is kept stable so auto-generated sub-layer (and weight)
    // names match models built with earlier versions of this layer.
    this.input1KeysProjection = tf.layers.dense({
      ...input1.keysArgs,
      useBias: false,
    });
    this.input1QueriesProjection = tf.layers.dense({
      ...input1.queriesArgs,
      useBias: false,
    });
    this.input1ValuesProjection = tf.layers.dense({
      ...input1.valuesArgs,
      useBias: false,
    });
    this.input1OutputProjection = tf.layers.dense({
      ...input1.outputArgs,
      useBias: false,
      units: dims,
    });
    this.input1ForwardFeed1 = tf.layers.dense({
      ...input1.forwardFeed1Args,
      useBias: false,
    });
    this.input1Activation = tf.layers.activation({ activation: "gelu" });
    this.input1ForwardFeed2 = tf.layers.dense({
      ...input1.forwardFeed2Args,
      units: dims,
      useBias: false,
    });

    this.input2KeysProjection = tf.layers.dense({
      ...input2.keysArgs,
      useBias: false,
    });
    this.input2QueriesProjection = tf.layers.dense({
      ...input2.queriesArgs,
      useBias: false,
    });
    this.input2ValuesProjection = tf.layers.dense({
      ...input2.valuesArgs,
      useBias: false,
    });
    this.input2OutputProjection = tf.layers.dense({
      ...input2.outputArgs,
      useBias: false,
      units: dims,
    });

    this.outputKeysProjection = tf.layers.dense({
      ...output.keysArgs,
      useBias: false,
    });
    this.outputQueriesProjection = tf.layers.dense({
      ...output.queriesArgs,
      useBias: false,
    });
    this.outputValuesProjection = tf.layers.dense({
      ...output.valuesArgs,
      useBias: false,
    });
    this.outputOutputProjection = tf.layers.dense({
      ...output.outputArgs,
      useBias: false,
      units: dims,
    });
    this.outputForwardFeed1 = tf.layers.dense({
      ...output.forwardFeed1Args,
      useBias: false,
    });
    this.outputActivation = tf.layers.activation({ activation: "gelu" });
    this.outputForwardFeed2 = tf.layers.dense({
      ...output.forwardFeed2Args,
      units: dims,
      useBias: false,
    });

    // NOTE(review): one BatchNormalization with `trainable: false` is shared
    // by every residual connection and is applied without a `training` flag.
    // Its statistics are presumably never updated, making it close to an
    // identity op. The T5 paper uses a separate simplified layer norm per
    // sub-layer. Changing this alters the weight layout, so it is kept as-is.
    this.normalizeLayer = tf.layers.batchNormalization({ trainable: false });

    this.input1KeysProjection.build(shape);
    this.input1QueriesProjection.build(shape);
    this.input1ValuesProjection.build(shape);
    this.input1OutputProjection.build(
      this.input1ValuesProjection.computeOutputShape(shape),
    );

    const input1HiddenShape = this.input1ForwardFeed1.computeOutputShape(shape);
    this.input1ForwardFeed1.build(shape);
    this.input1Activation.build(input1HiddenShape);
    this.input1ForwardFeed2.build(input1HiddenShape);

    this.input2KeysProjection.build(shape);
    this.input2QueriesProjection.build(shape);
    this.input2ValuesProjection.build(shape);
    this.input2OutputProjection.build(
      this.input2ValuesProjection.computeOutputShape(shape),
    );

    this.outputKeysProjection.build(shape);
    this.outputQueriesProjection.build(shape);
    this.outputValuesProjection.build(shape);
    this.outputOutputProjection.build(
      this.outputValuesProjection.computeOutputShape(shape),
    );

    const outputHiddenShape = this.outputForwardFeed1.computeOutputShape(shape);
    this.outputForwardFeed1.build(shape);
    this.outputActivation.build(outputHiddenShape);
    this.outputForwardFeed2.build(outputHiddenShape);

    this.normalizeLayer.build(shape);

    const trainable: tf.layers.Layer["trainableWeights"] = [];
    const nonTrainable: tf.layers.Layer["nonTrainableWeights"] = [];
    for (const layer of this.weightedLayers()) {
      trainable.push(...layer.trainableWeights);
      nonTrainable.push(...layer.nonTrainableWeights);
    }
    this.trainableWeights = trainable;
    this.nonTrainableWeights = nonTrainable;

    this.built = true;
  }

  /**
   * Runs both branches and the cross-attention block.
   *
   * @param input A tensor or `[input1, input2]`; see `getInputs`.
   * @returns A tensor shaped like the last input.
   */
  call(input: tf.Tensor | tf.Tensor[]): tf.Tensor | tf.Tensor[] {
    return tf.tidy(() => {
      const [input1, input2] = this.getInputs(input);

      const encodeInput1 = this.encodeInput1(input1);
      const encodeInput2 = this.encodeInput2(input2);

      return this.decodeOutput(encodeInput1, encodeInput2);
    });
  }

  /**
   * Serializes the constructor options so a registered layer can be rebuilt
   * by `tf.loadLayersModel`.
   *
   * NOTE(review): the dense option objects are stored verbatim. They
   * round-trip only when they hold JSON-friendly values (numbers, strings,
   * initializer names); initializer/regularizer *instances* will not.
   */
  getConfig(): tf.serialization.ConfigDict {
    // The DenseLayerArgs types are not statically ConfigDict-compatible, so
    // an explicit conversion is required here.
    const ownConfig = {
      numHeads: this.options.numHeads,
      input1: this.options.input1,
      input2: this.options.input2,
      output: this.options.output,
    } as unknown as tf.serialization.ConfigDict;

    return { ...super.getConfig(), ...ownConfig };
  }

  /**
   * Multi-head scaled dot-product attention.
   *
   * @param query Queries, `[..., seqQ, qUnits]`.
   * @param key Keys, `[..., seqK, qUnits]` (same width as the queries).
   * @param value Values, `[..., seqK, vUnits]`.
   * @param causal When true, query position `i` may only attend to key
   *   positions `j <= i`.
   * @returns `[..., seqQ, vUnits]`.
   */
  private attention(
    query: tf.Tensor,
    key: tf.Tensor,
    value: tf.Tensor,
    causal: boolean,
  ): tf.Tensor {
    return tf.tidy(() => {
      // [..., heads, seq, depth]
      const queryHeads = this.split(query);
      const keyHeads = this.split(key);
      const valueHeads = this.split(value);

      const depth = queryHeads.shape[queryHeads.rank - 1] as number;

      // [..., heads, seqQ, seqK], scaled by 1/sqrt(per-head depth).
      let scores = tf
        .matMul(queryHeads, keyHeads, false, true)
        .div(Math.sqrt(depth));

      if (causal) {
        const queryLength = scores.shape[scores.rank - 2] as number;
        const keyLength = scores.shape[scores.rank - 1] as number;
        // 1 on and below the diagonal (allowed), 0 above (future positions).
        const allowed = tf.linalg.bandPart(
          tf.ones([queryLength, keyLength]),
          -1,
          0,
        );
        // Push disallowed logits towards -inf so softmax gives them ~0 weight.
        // Multiplying by 0 would still leave them with exp(0) weight.
        scores = scores.add(tf.sub(1, allowed).mul(-1e9));
      }

      const attentionWeights = tf.softmax(scores, -1);

      return this.merge(tf.matMul(attentionWeights, valueHeads));
    });
  }

  /** Input-1 branch: self-attention + residual + norm, then feed-forward + residual + norm. */
  private encodeInput1(input: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
      const keys = this.input1KeysProjection.apply(input) as tf.Tensor;
      const queries = this.input1QueriesProjection.apply(input) as tf.Tensor;
      const values = this.input1ValuesProjection.apply(input) as tf.Tensor;

      const attended = this.attention(queries, keys, values, false);
      const attention = this.input1OutputProjection.apply(
        attended,
      ) as tf.Tensor;

      const addAttention = tf.add(attention, input);
      const normAttention = this.normalizeLayer.apply(
        addAttention,
      ) as tf.Tensor;

      const feed1 = this.input1ForwardFeed1.apply(normAttention) as tf.Tensor;
      const active = this.input1Activation.apply(feed1) as tf.Tensor;
      const feed2 = this.input1ForwardFeed2.apply(active) as tf.Tensor;

      const addFeed = tf.add(feed2, normAttention);
      return this.normalizeLayer.apply(addFeed) as tf.Tensor;
    });
  }

  /** Input-2 branch: causally masked self-attention + residual + norm. */
  private encodeInput2(input: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
      const keys = this.input2KeysProjection.apply(input) as tf.Tensor;
      const queries = this.input2QueriesProjection.apply(input) as tf.Tensor;
      const values = this.input2ValuesProjection.apply(input) as tf.Tensor;

      const attended = this.attention(queries, keys, values, true);
      const attention = this.input2OutputProjection.apply(
        attended,
      ) as tf.Tensor;

      const addAttention = tf.add(attention, input);
      return this.normalizeLayer.apply(addAttention) as tf.Tensor;
    });
  }

  /**
   * Cross-attention: queries come from the input-2 stream, and keys/values
   * from the input-1 stream. This is followed by feed-forward, with
   * residual + norm after each step.
   */
  private decodeOutput(
    encodeInput1: tf.Tensor,
    encodeInput2: tf.Tensor,
  ): tf.Tensor {
    return tf.tidy(() => {
      const keys = this.outputKeysProjection.apply(encodeInput1) as tf.Tensor;
      const queries = this.outputQueriesProjection.apply(
        encodeInput2,
      ) as tf.Tensor;
      const values = this.outputValuesProjection.apply(
        encodeInput1,
      ) as tf.Tensor;

      const attended = this.attention(queries, keys, values, false);
      const attention = this.outputOutputProjection.apply(
        attended,
      ) as tf.Tensor;

      // The attention output has the query (input-2) sequence length, so the
      // residual must come from the query stream, not the key/value stream.
      const addAttention = tf.add(attention, encodeInput2);
      const normAttention = this.normalizeLayer.apply(
        addAttention,
      ) as tf.Tensor;

      const feed1 = this.outputForwardFeed1.apply(normAttention) as tf.Tensor;
      const active = this.outputActivation.apply(feed1) as tf.Tensor;
      const feed2 = this.outputForwardFeed2.apply(active) as tf.Tensor;

      const addFeed = tf.add(feed2, normAttention);
      return this.normalizeLayer.apply(addFeed) as tf.Tensor;
    });
  }

  /**
   * Normalizes the layer input to `[input1, input2]`. A lone tensor or a
   * one-element array feeds both branches.
   *
   * @throws Error if given an array with a length other than 1 or 2.
   */
  private getInputs(input: tf.Tensor | tf.Tensor[]): [tf.Tensor, tf.Tensor] {
    if (!Array.isArray(input)) {
      return [input, input];
    }
    if (input.length === 1) {
      return [input[0], input[0]];
    }
    if (input.length === 2) {
      return [input[0], input[1]];
    }
    throw new Error(
      `T5Transformer: expected 1 or 2 inputs, got ${input.length}.`,
    );
  }

  /** Weight-owning sub-layers, in the order their weights are exposed. */
  private weightedLayers(): tf.layers.Layer[] {
    return [
      this.input1KeysProjection,
      this.input1QueriesProjection,
      this.input1ValuesProjection,
      this.input1OutputProjection,
      this.input1ForwardFeed1,
      this.input1ForwardFeed2,
      this.input2KeysProjection,
      this.input2QueriesProjection,
      this.input2ValuesProjection,
      this.input2OutputProjection,
      this.outputKeysProjection,
      this.outputQueriesProjection,
      this.outputValuesProjection,
      this.outputOutputProjection,
      this.outputForwardFeed1,
      this.outputForwardFeed2,
      this.normalizeLayer,
    ];
  }

  /**
   * Splits the feature axis into heads and moves the head axis before the
   * sequence axis: `[..., seq, units]` -> `[..., heads, seq, units / heads]`.
   *
   * @throws Error if the input rank is below 2 or `units` is not divisible
   *   by `numHeads`.
   */
  private split(input: tf.Tensor): tf.Tensor {
    const { numHeads } = this.options;
    const rank = input.rank;

    if (rank < 2) {
      throw new Error(
        `T5Transformer: expected tensors of rank >= 2 ([..., seq, features]), got rank ${rank}.`,
      );
    }

    const units = input.shape[rank - 1] as number;
    if (units % numHeads !== 0) {
      throw new Error(
        `T5Transformer: projection size ${units} is not divisible by numHeads (${numHeads}).`,
      );
    }

    const heads = input.reshape([
      ...input.shape.slice(0, rank - 1),
      numHeads,
      units / numHeads,
    ]);

    // Swap the seq and heads axes so attention mixes sequence positions
    // within each head rather than mixing heads within one position.
    const perm = Array.from({ length: rank + 1 }, (_, i) => i);
    perm[rank - 2] = rank - 1;
    perm[rank - 1] = rank - 2;

    return heads.transpose(perm);
  }

  /**
   * Inverse of `split`:
   * `[..., heads, seq, depth]` -> `[..., seq, heads * depth]`.
   */
  private merge(input: tf.Tensor): tf.Tensor {
    const rank = input.rank;

    const perm = Array.from({ length: rank }, (_, i) => i);
    perm[rank - 3] = rank - 2;
    perm[rank - 2] = rank - 3;
    const transposed = input.transpose(perm);

    const heads = transposed.shape[rank - 2] as number;
    const depth = transposed.shape[rank - 1] as number;

    return transposed.reshape([
      ...transposed.shape.slice(0, rank - 2),
      heads * depth,
    ]);
  }

  static className = "T5Transformer";

  /**
   * The output has the shape of the query (last) input. With one input, that
   * input's shape is returned unchanged.
   */
  computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
    if (Array.isArray(inputShape[0])) {
      const shapes = inputShape as tf.Shape[];
      return shapes[shapes.length - 1];
    }

    return inputShape;
  }
}

// Register under `_T5Transformer.className` so serialized models that
// reference this layer can be resolved back to the class.
tf.serialization.registerClass(_T5Transformer);

/**
 * Factory in the style of `tf.layers.*`: builds a new T5 transformer layer.
 *
 * @param options Layer configuration; see `T5TransformerLayerArgs`.
 * @returns A fresh, unbuilt `_T5Transformer` instance.
 */
export const T5Transformer = (options: T5TransformerLayerArgs): _T5Transformer =>
  new _T5Transformer(options);

lukemovement avatar Jun 02 '24 13:06 lukemovement