tfjs
tfjs copied to clipboard
T5 Text-to-Text Transformer
System information
- TensorFlow.js version (you are using): 4.19.0
- Are you willing to contribute it (Yes/No): yes
Describe the feature and the current behavior/state.

Will this change the current API? How? N/A
Who will benefit from this feature? Anyone looking to use the layer.
Any other info: I have, to the best of my knowledge, built the transformer introduced in the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" by Colin Raffel et al. (2020). I'm looking to contribute it; however, I may need some assistance bringing it in line with project standards.
Am I in the ballpark, or is this a swing and a miss?
import * as tf from "@tensorflow/tfjs";
import type { DenseLayerArgs } from "@tensorflow/tfjs-layers/dist/layers/core";
/** Dense options for a projection whose output width (`units`) the caller chooses. */
type T5ProjectionArgs = Omit<DenseLayerArgs, "activation" | "useBias">;

/**
 * Dense options for a projection whose width the layer fixes to the model
 * (feature) dimension, so `units` is not accepted.
 */
type T5ModelWidthProjectionArgs = Omit<
  DenseLayerArgs,
  "units" | "activation" | "useBias"
>;

/** Query / key / value / output projections of one multi-head attention sub-layer. */
interface T5AttentionArgs {
  queriesArgs: T5ProjectionArgs;
  keysArgs: T5ProjectionArgs;
  valuesArgs: T5ProjectionArgs;
  outputArgs: T5ModelWidthProjectionArgs;
}

/** The two Dense layers of a position-wise feed-forward sub-layer. */
interface T5FeedForwardArgs {
  forwardFeed1Args: T5ProjectionArgs;
  forwardFeed2Args: T5ModelWidthProjectionArgs;
}

/**
 * Constructor options for the T5 transformer layer.
 *
 * `activation` and `useBias` are never configurable: the layer creates every
 * Dense projection without a bias and without an activation of its own.
 */
export interface T5TransformerLayerArgs {
  /** Optional layer name, forwarded to the base `Layer`. */
  name?: string;
  /** Number of attention heads; projection widths are split evenly across them. */
  numHeads: number;
  /** First (encoder) input stream: self-attention plus feed-forward. */
  input1: T5AttentionArgs & T5FeedForwardArgs;
  /** Second (decoder) input stream: masked self-attention. */
  input2: T5AttentionArgs;
  /** Cross-attention over the two streams plus feed-forward. */
  output: T5AttentionArgs & T5FeedForwardArgs;
}
/**
 * Single-block T5-style encoder–decoder layer.
 *
 * Accepts either one tensor (used as both the encoder and the decoder input)
 * or a pair `[encoderInput, decoderInput]`. The encoder stream goes through
 * self-attention + feed-forward, the decoder stream through causally masked
 * self-attention, and the result is cross-attention (queries from the decoder,
 * keys/values from the encoder) + feed-forward. The output follows the
 * decoder stream's shape.
 *
 * Assumes inputs are `[batch, seq, dims]` with the same `dims` for both
 * inputs, and attention projection widths divisible by `numHeads` — TODO
 * confirm against callers.
 *
 * NOTE(review): one normalization layer (one set of scale weights) is shared
 * by every sub-layer; the paper uses a separate one per sub-layer.
 * NOTE(review): there is no `getConfig` override, so deserializing through the
 * registered class cannot reconstruct `options`.
 */
export class _T5Transformer extends tf.layers.Layer {
  static className = "T5Transformer";

  private readonly options: T5TransformerLayerArgs;
  private input1KeysProjection!: tf.layers.Layer;
  private input1ValuesProjection!: tf.layers.Layer;
  private input1QueriesProjection!: tf.layers.Layer;
  private input1OutputProjection!: tf.layers.Layer;
  private input2KeysProjection!: tf.layers.Layer;
  private input2ValuesProjection!: tf.layers.Layer;
  private input2QueriesProjection!: tf.layers.Layer;
  private input2OutputProjection!: tf.layers.Layer;
  private outputKeysProjection!: tf.layers.Layer;
  private outputValuesProjection!: tf.layers.Layer;
  private outputQueriesProjection!: tf.layers.Layer;
  private outputOutputProjection!: tf.layers.Layer;
  private input1ForwardFeed1!: tf.layers.Layer;
  private input1Activation!: tf.layers.Layer;
  private input1ForwardFeed2!: tf.layers.Layer;
  private outputForwardFeed1!: tf.layers.Layer;
  private outputActivation!: tf.layers.Layer;
  private outputForwardFeed2!: tf.layers.Layer;
  private normalizeLayer!: tf.layers.Layer;

  constructor(options: T5TransformerLayerArgs) {
    super(options);
    this.options = options;
  }

  /**
   * Creates and builds every sub-layer, then exposes their weights as this
   * layer's weights.
   *
   * @param inputShape Shape of the single input, or a list of shapes; in the
   *   latter case only the first shape is used (Dense and LayerNormalization
   *   only depend on the last axis).
   */
  build(inputShape: tf.Shape | tf.Shape[]): void {
    if (Array.isArray(inputShape[0])) {
      inputShape = inputShape[0];
    }
    // Model (feature) width; output projections and feed-forward outputs are
    // sized to it so the residual additions line up.
    const dims = inputShape[inputShape.length - 1] as number;
    this.input1KeysProjection = tf.layers.dense({
      ...this.options.input1.keysArgs,
      useBias: false,
    });
    this.input1QueriesProjection = tf.layers.dense({
      ...this.options.input1.queriesArgs,
      useBias: false,
    });
    this.input1ValuesProjection = tf.layers.dense({
      ...this.options.input1.valuesArgs,
      useBias: false,
    });
    this.input1OutputProjection = tf.layers.dense({
      ...this.options.input1.outputArgs,
      useBias: false,
      units: dims,
    });
    this.input1ForwardFeed1 = tf.layers.dense({
      ...this.options.input1.forwardFeed1Args,
      useBias: false,
    });
    this.input1Activation = tf.layers.activation({
      activation: "gelu",
    });
    this.input1ForwardFeed2 = tf.layers.dense({
      ...this.options.input1.forwardFeed2Args,
      units: dims,
      useBias: false,
    });
    this.input2KeysProjection = tf.layers.dense({
      ...this.options.input2.keysArgs,
      useBias: false,
    });
    this.input2QueriesProjection = tf.layers.dense({
      ...this.options.input2.queriesArgs,
      useBias: false,
    });
    this.input2ValuesProjection = tf.layers.dense({
      ...this.options.input2.valuesArgs,
      useBias: false,
    });
    this.input2OutputProjection = tf.layers.dense({
      ...this.options.input2.outputArgs,
      useBias: false,
      units: dims,
    });
    this.outputKeysProjection = tf.layers.dense({
      ...this.options.output.keysArgs,
      useBias: false,
    });
    this.outputQueriesProjection = tf.layers.dense({
      ...this.options.output.queriesArgs,
      useBias: false,
    });
    this.outputValuesProjection = tf.layers.dense({
      ...this.options.output.valuesArgs,
      useBias: false,
    });
    this.outputOutputProjection = tf.layers.dense({
      ...this.options.output.outputArgs,
      useBias: false,
      units: dims,
    });
    this.outputForwardFeed1 = tf.layers.dense({
      ...this.options.output.forwardFeed1Args,
      useBias: false,
    });
    this.outputActivation = tf.layers.activation({
      activation: "gelu",
    });
    this.outputForwardFeed2 = tf.layers.dense({
      ...this.options.output.forwardFeed2Args,
      units: dims,
      useBias: false,
    });
    // Per-position layer normalization over the feature axis. T5 only
    // rescales (no additive bias), hence `center: false`. Batch normalization
    // is not a substitute here: applied without `training`, it only uses its
    // moving statistics and does not normalize the activations.
    this.normalizeLayer = tf.layers.layerNormalization({ center: false });
    this.input1KeysProjection.build(inputShape);
    this.input1QueriesProjection.build(inputShape);
    this.input1ValuesProjection.build(inputShape);
    this.input1OutputProjection.build(
      this.input1ValuesProjection.computeOutputShape(inputShape),
    );
    this.input1ForwardFeed1.build(inputShape);
    this.input1Activation.build(
      this.input1ForwardFeed1.computeOutputShape(inputShape),
    );
    this.input1ForwardFeed2.build(
      this.input1ForwardFeed1.computeOutputShape(inputShape),
    );
    this.input2KeysProjection.build(inputShape);
    this.input2QueriesProjection.build(inputShape);
    this.input2ValuesProjection.build(inputShape);
    this.input2OutputProjection.build(
      this.input2ValuesProjection.computeOutputShape(inputShape),
    );
    this.outputKeysProjection.build(inputShape);
    this.outputQueriesProjection.build(inputShape);
    this.outputValuesProjection.build(inputShape);
    this.outputOutputProjection.build(
      this.outputValuesProjection.computeOutputShape(inputShape),
    );
    this.outputForwardFeed1.build(inputShape);
    this.outputActivation.build(
      this.outputForwardFeed1.computeOutputShape(inputShape),
    );
    this.outputForwardFeed2.build(
      this.outputForwardFeed1.computeOutputShape(inputShape),
    );
    this.normalizeLayer.build(inputShape);
    this.trainableWeights = [
      ...this.input1KeysProjection.trainableWeights,
      ...this.input1QueriesProjection.trainableWeights,
      ...this.input1ValuesProjection.trainableWeights,
      ...this.input1OutputProjection.trainableWeights,
      ...this.input1ForwardFeed1.trainableWeights,
      ...this.input1ForwardFeed2.trainableWeights,
      ...this.input2KeysProjection.trainableWeights,
      ...this.input2QueriesProjection.trainableWeights,
      ...this.input2ValuesProjection.trainableWeights,
      ...this.input2OutputProjection.trainableWeights,
      ...this.outputKeysProjection.trainableWeights,
      ...this.outputQueriesProjection.trainableWeights,
      ...this.outputValuesProjection.trainableWeights,
      ...this.outputOutputProjection.trainableWeights,
      ...this.outputForwardFeed1.trainableWeights,
      ...this.outputForwardFeed2.trainableWeights,
      ...this.normalizeLayer.trainableWeights,
    ];
    this.nonTrainableWeights = [
      ...this.input1KeysProjection.nonTrainableWeights,
      ...this.input1QueriesProjection.nonTrainableWeights,
      ...this.input1ValuesProjection.nonTrainableWeights,
      ...this.input1OutputProjection.nonTrainableWeights,
      ...this.input1ForwardFeed1.nonTrainableWeights,
      ...this.input1ForwardFeed2.nonTrainableWeights,
      ...this.input2KeysProjection.nonTrainableWeights,
      ...this.input2QueriesProjection.nonTrainableWeights,
      ...this.input2ValuesProjection.nonTrainableWeights,
      ...this.input2OutputProjection.nonTrainableWeights,
      ...this.outputKeysProjection.nonTrainableWeights,
      ...this.outputQueriesProjection.nonTrainableWeights,
      ...this.outputValuesProjection.nonTrainableWeights,
      ...this.outputOutputProjection.nonTrainableWeights,
      ...this.outputForwardFeed1.nonTrainableWeights,
      ...this.outputForwardFeed2.nonTrainableWeights,
      ...this.normalizeLayer.nonTrainableWeights,
    ];
    this.built = true;
  }

  /**
   * Forward pass.
   *
   * @param input One tensor (used for both streams) or `[encoderInput, decoderInput]`.
   * @returns The decoder stream after cross-attention and feed-forward.
   */
  call(input: tf.Tensor | tf.Tensor[]): tf.Tensor | tf.Tensor[] {
    return tf.tidy(() => {
      const [input1, input2] = this.getInputs(input);
      const encodeInput1 = this.encodeInput1(input1);
      const encodeInput2 = this.encodeInput2(input2);
      return this.decodeOutput(encodeInput1, encodeInput2);
    });
  }

  /**
   * Scaled dot-product multi-head attention.
   *
   * @param query Projected queries, `[..., seqQ, width]`.
   * @param key Projected keys, `[..., seqK, width]`.
   * @param value Projected values, `[..., seqK, width]`.
   * @param mask When true, position i may only attend to positions j <= i.
   * @returns Heads merged back to `[..., seqQ, width]`.
   */
  private attention(
    query: tf.Tensor,
    key: tf.Tensor,
    value: tf.Tensor,
    mask: boolean,
  ): tf.Tensor {
    return tf.tidy(() => {
      // [..., heads, seq, depth]: heads act as an extra batch axis so that
      // matMul contracts over sequence positions.
      const queryHeads = this.split(query);
      const keyHeads = this.split(key);
      const valueHeads = this.split(value);
      const depth = queryHeads.shape[queryHeads.rank - 1] as number;
      // [..., heads, seqQ, seqK], scaled by 1 / sqrt(depth per head).
      let scores = tf.div(
        tf.matMul(queryHeads, keyHeads, false, true),
        Math.sqrt(depth),
      );
      if (mask) {
        const seqQ = scores.shape[scores.rank - 2] as number;
        const seqK = scores.shape[scores.rank - 1] as number;
        // 1 on and below the diagonal, 0 above; broadcasts over batch/heads.
        const allowed = tf.linalg.bandPart(tf.ones([seqQ, seqK]), -1, 0);
        // Additive mask: disallowed logits become ~-1e9, so softmax gives them
        // ~0 weight. Multiplying logits by 0 would still leave exp(0) mass.
        scores = tf.add(scores, tf.mul(tf.sub(1, allowed), -1e9));
      }
      const attentionWeights = tf.softmax(scores, -1);
      return this.merge(tf.matMul(attentionWeights, valueHeads));
    });
  }

  /** Encoder stream: unmasked self-attention and feed-forward, each with residual + norm. */
  private encodeInput1(input: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
      const keys = this.input1KeysProjection.apply(input) as tf.Tensor;
      const queries = this.input1QueriesProjection.apply(input) as tf.Tensor;
      const values = this.input1ValuesProjection.apply(input) as tf.Tensor;
      const attended = this.attention(queries, keys, values, false);
      const attention = this.input1OutputProjection.apply(
        attended,
      ) as tf.Tensor;
      const addAttention = tf.add(attention, input);
      const normAttention = this.normalizeLayer.apply(
        addAttention,
      ) as tf.Tensor;
      const feed1 = this.input1ForwardFeed1.apply(normAttention) as tf.Tensor;
      const active = this.input1Activation.apply(feed1) as tf.Tensor;
      const feed2 = this.input1ForwardFeed2.apply(active) as tf.Tensor;
      const addFeed = tf.add(feed2, normAttention);
      return this.normalizeLayer.apply(addFeed) as tf.Tensor;
    });
  }

  /** Decoder stream: causally masked self-attention with residual + norm. */
  private encodeInput2(input: tf.Tensor): tf.Tensor {
    return tf.tidy(() => {
      const keys = this.input2KeysProjection.apply(input) as tf.Tensor;
      const queries = this.input2QueriesProjection.apply(input) as tf.Tensor;
      const values = this.input2ValuesProjection.apply(input) as tf.Tensor;
      const attended = this.attention(queries, keys, values, true);
      const attention = this.input2OutputProjection.apply(
        attended,
      ) as tf.Tensor;
      const addAttention = tf.add(attention, input);
      return this.normalizeLayer.apply(addAttention) as tf.Tensor;
    });
  }

  /**
   * Cross-attention (queries from the decoder stream, keys/values from the
   * encoder stream) followed by feed-forward, each with residual + norm.
   *
   * @param encodeInput1 Encoder stream, `[batch, seqEnc, dims]`.
   * @param encodeInput2 Decoder stream, `[batch, seqDec, dims]`.
   * @returns Tensor shaped like `encodeInput2`.
   */
  private decodeOutput(
    encodeInput1: tf.Tensor,
    encodeInput2: tf.Tensor,
  ): tf.Tensor {
    return tf.tidy(() => {
      const keys = this.outputKeysProjection.apply(encodeInput1) as tf.Tensor;
      const queries = this.outputQueriesProjection.apply(
        encodeInput2,
      ) as tf.Tensor;
      const values = this.outputValuesProjection.apply(
        encodeInput1,
      ) as tf.Tensor;
      const attended = this.attention(queries, keys, values, false);
      const attention = this.outputOutputProjection.apply(
        attended,
      ) as tf.Tensor;
      // The attention output has the decoder's sequence length, so the
      // residual must come from the decoder stream (the query source).
      const addAttention = tf.add(attention, encodeInput2);
      const normAttention = this.normalizeLayer.apply(
        addAttention,
      ) as tf.Tensor;
      const feed1 = this.outputForwardFeed1.apply(normAttention) as tf.Tensor;
      const active = this.outputActivation.apply(feed1) as tf.Tensor;
      const feed2 = this.outputForwardFeed2.apply(active) as tf.Tensor;
      const addFeed = tf.add(feed2, normAttention);
      return this.normalizeLayer.apply(addFeed) as tf.Tensor;
    });
  }

  /**
   * Resolves the call input into `[encoderInput, decoderInput]`.
   *
   * @throws Error when given an array whose length is not 1 or 2.
   */
  private getInputs(input: tf.Tensor | tf.Tensor[]): [tf.Tensor, tf.Tensor] {
    if (!Array.isArray(input)) {
      return [input, input];
    }
    if (input.length === 1) {
      return [input[0], input[0]];
    }
    if (input.length === 2) {
      return [input[0], input[1]];
    }
    throw new Error(
      `T5Transformer expects 1 or 2 input tensors, got ${input.length}.`,
    );
  }

  /**
   * Splits the feature axis into heads and moves heads in front of the
   * sequence axis: `[..., seq, width]` -> `[..., numHeads, seq, width / numHeads]`.
   *
   * @throws Error when the input has no sequence axis (rank < 3) or its width
   *   is not divisible by `numHeads`.
   */
  private split(input: tf.Tensor): tf.Tensor {
    const { numHeads } = this.options;
    const shape = [...input.shape];
    const rank = shape.length;
    if (rank < 3) {
      throw new Error(
        `T5Transformer attention expects inputs of rank >= 3 ([batch, seq, dims]), got rank ${rank}.`,
      );
    }
    const width = shape[rank - 1];
    if (width % numHeads !== 0) {
      throw new Error(
        `T5Transformer projection width ${width} is not divisible by numHeads (${numHeads}).`,
      );
    }
    const headsLast = input.reshape([
      ...shape.slice(0, rank - 1),
      numHeads,
      width / numHeads,
    ]);
    // After the reshape the sequence axis is at rank - 2 and heads at rank - 1.
    return tf.transpose(
      headsLast,
      _T5Transformer.swapAxes(rank + 1, rank - 2, rank - 1),
    );
  }

  /**
   * Inverse of `split`: `[..., numHeads, seq, depth]` -> `[..., seq, numHeads * depth]`.
   */
  private merge(input: tf.Tensor): tf.Tensor {
    const rank = input.rank;
    const seqMajor = tf.transpose(
      input,
      _T5Transformer.swapAxes(rank, rank - 3, rank - 2),
    );
    const shape = [...seqMajor.shape];
    return seqMajor.reshape([
      ...shape.slice(0, rank - 2),
      shape[rank - 2] * shape[rank - 1],
    ]);
  }

  /** Identity permutation of `rank` axes with axes `a` and `b` exchanged. */
  private static swapAxes(rank: number, a: number, b: number): number[] {
    const perm: number[] = [];
    for (let axis = 0; axis < rank; axis++) {
      perm.push(axis);
    }
    perm[a] = b;
    perm[b] = a;
    return perm;
  }

  /**
   * The output follows the decoder (last) input: its sequence length comes
   * from the cross-attention queries and the residual around them.
   */
  computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
    if (Array.isArray(inputShape[0])) {
      const shapes = inputShape as tf.Shape[];
      return shapes[shapes.length - 1];
    }
    return inputShape;
  }
}
// Register the layer class with tf.serialization under its static `className`.
tf.serialization.registerClass(_T5Transformer);

/**
 * Factory for the T5 transformer layer, mirroring the `tf.layers.*` style.
 *
 * @param options Layer configuration.
 * @returns A new layer instance; sub-layers are created later, in `build`.
 */
export const T5Transformer = (
  options: T5TransformerLayerArgs,
): _T5Transformer => new _T5Transformer(options);