
How to implement metrics with a signature other than `fn(y_true, y_pred)`?

early-stopper opened this issue 1 year ago · 3 comments

To add a loss and metrics to a model, I can pass them to model.compile(loss=..., metrics=...), provided they have the signature fn(y_true, y_pred), see the docs. If I have a loss with a different signature, e.g. because it also depends on intermediate states of the model, I can instead add it via add_loss in the model's call method. The same is no longer true for metrics, because the corresponding add_metric method was deprecated in Keras 3.

How am I supposed to implement metrics that do not follow the signature fn(y_true, y_pred), but require values from hidden layers inside the model?
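For concreteness, here is a minimal sketch of the two patterns described above (the model and the regularization term are made up purely for illustration):

import keras
from keras import layers, ops


# Pattern 1: a metric with signature fn(y_true, y_pred), passed to compile().
def my_metric(y_true, y_pred):
    return ops.mean(ops.abs(y_true - y_pred))


model = keras.Sequential([layers.Dense(1)])
model.compile(optimizer="adam", loss="mse", metrics=[my_metric])


# Pattern 2: a loss that depends on intermediate state, added via add_loss()
# inside call(). There is no Keras 3 counterpart of add_metric() for this case.
class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden = layers.Dense(8, activation="relu")
        self.out = layers.Dense(1)

    def call(self, inputs):
        h = self.hidden(inputs)
        # Depends on a hidden activation, so not expressible as fn(y_true, y_pred).
        self.add_loss(0.01 * ops.mean(ops.square(h)))
        return self.out(h)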

early-stopper · Oct 22 '24 10:10

To add some background: I am implementing a variational autoencoder with total_loss = gamma*reconstruction_loss + (1-gamma)*kl_loss, and I want both reconstruction_loss and kl_loss to be logged separately during training. In Keras 2, I could do this by additionally registering both losses as metrics:

import keras
import tensorflow as tf


class VariationalAutoEncoder(keras.Model):
    def __init__(self, encoder, decoder, gamma=0.999, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.gamma = gamma

    def call(self, inputs):
        mean, log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        
        reconstruction_loss = keras.losses.MeanSquaredError()(inputs, reconstructed)
        kullback_leibler_loss = -0.5*tf.reduce_mean(log_var-tf.square(mean)-tf.exp(log_var)+1)
        self.add_loss(self.gamma*reconstruction_loss+(1-self.gamma)*kullback_leibler_loss)
        
        # Additionally add losses as metrics to track them separately
        self.add_metric(reconstruction_loss, name='reconstruction_loss')  # <- DEPRECATED IN KERAS 3
        self.add_metric(kullback_leibler_loss, name='kl_loss')  # <- DEPRECATED IN KERAS 3

        return reconstructed

This no longer works in Keras 3 because add_metric is deprecated.

I am aware of the Keras example implementing a VAE. The approach there is to define additional trackers for the losses and to override train_step to update them (condensed sketch below). Messing with the training logic just to log extra metrics seems like overkill to me, though, and I was hoping for a cleaner approach in the style of add_metric. I guess what I am asking is: can we have add_metric back?
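For reference, a condensed sketch of that tracker-plus-train_step pattern (loosely following the Keras VAE example; TensorFlow backend assumed, and the encoder/decoder are taken to have the same signatures as in the snippet above):

import keras
import tensorflow as tf
from keras import ops


class TrackedVAE(keras.Model):
    def __init__(self, encoder, decoder, gamma=0.999, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.gamma = gamma
        # One Mean tracker per quantity that should appear in the logs.
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Trackers listed here are reset automatically at the start of each epoch.
        return [self.reconstruction_loss_tracker, self.kl_loss_tracker]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            mean, log_var, z = self.encoder(data)
            reconstructed = self.decoder(z)
            reconstruction_loss = ops.mean(ops.square(data - reconstructed))
            kl_loss = -0.5 * ops.mean(log_var - ops.square(mean) - ops.exp(log_var) + 1)
            total_loss = self.gamma * reconstruction_loss + (1 - self.gamma) * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}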

early-stopper · Oct 22 '24 10:10

This can be done using the new structured loss feature:

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np

import keras
from keras import layers
from keras import ops


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = keras.random.normal(
            shape=(batch, dim), seed=self.seed_generator
        )
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(
    encoder_inputs,
    {"z_mean": z_mean, "z_log_var": z_log_var, "z": z},
    name="encoder",
)
encoder.summary()

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()


# Combine encoder and decoder into a single model with named (structured)
# outputs, so that compile() can attach a separate loss to each output.
vae = keras.Model(
    encoder.input,
    {"encoder": encoder.output, "decoder": decoder(encoder.output["z"])},
    name="vae",
)

def encoder_loss_fn(y_true, y_pred):
    z_mean, z_log_var, z = y_pred["z_mean"], y_pred["z_log_var"], y_pred["z"]
    kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
    kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
    return kl_loss

gamma = 0.5
# Each named loss is logged separately (as encoder_loss / decoder_loss),
# which covers the bookkeeping previously done with add_metric().
vae.compile(
    optimizer="adam",
    loss={"encoder": encoder_loss_fn, "decoder": "binary_crossentropy"},
    loss_weights={"encoder": 1 - gamma, "decoder": gamma},
)

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

labels = {
    "encoder": keras.tree.map_structure(
        # dummy data, not used in the loss
        lambda _: np.ones((mnist_digits.shape[0], latent_dim)), encoder.output
    ),
    "decoder": mnist_digits,
}

vae.fit(x=mnist_digits, y=labels, epochs=30, batch_size=128)

Epoch 1/30
547/547 ━━━━━━━━━━━━━━━━━━━━ 16s 27ms/step - decoder_loss: 0.1729 - encoder_loss: 2.5852e-05 - loss: 0.1729

nicolaspi · Oct 23 '24 07:10

You merged a solution two days before I posted this issue; this must be that predictive maintenance I've heard so much about ;-) Thank you for the quick response, I'll take a look.

early-stopper · Oct 23 '24 07:10

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] · Dec 04 '24 02:12

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

github-actions[bot] · Dec 19 '24 02:12
