How to implement metrics with a signature other than `fn(y_true, y_pred)`?
To add a loss and metrics to a model, I can pass them to `model.compile(loss=..., metrics=...)`, provided they have the signature `fn(y_true, y_pred)`, see the docs. If I have a loss with a different signature, e.g. because it also depends on some intermediate state of the model, I can instead add it with `add_loss` in the model's `call` method. The same is no longer possible for metrics, because the corresponding method `add_metric` was deprecated in Keras 3.
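For metrics that do follow that signature, `compile` is all that is needed, e.g. (a minimal sketch with a made-up metric and a toy model, just for contrast):

import keras

def mean_abs_error(y_true, y_pred):
    # standard fn(y_true, y_pred) signature
    return keras.ops.mean(keras.ops.abs(y_true - y_pred))

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse", metrics=[mean_abs_error])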
How am I supposed to implement metrics that do not follow the signature `fn(y_true, y_pred)`, but require values from hidden layers inside the model?
To add some background here: I am implementing a variational autoencoder with `total_loss = gamma * reconstruction_loss + (1 - gamma) * kl_loss`, and I want both `reconstruction_loss` and `kl_loss` to be logged separately during training. In Keras 2, I could do this by additionally adding both losses as metrics, as follows:
import keras
import tensorflow as tf

class VariationalAutoEncoder(keras.Model):
    def __init__(self, encoder, decoder, gamma=0.999, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.gamma = gamma

    def call(self, inputs):
        mean, log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        reconstruction_loss = keras.losses.MeanSquaredError()(inputs, reconstructed)
        kullback_leibler_loss = -0.5 * tf.reduce_mean(log_var - tf.square(mean) - tf.exp(log_var) + 1)
        self.add_loss(self.gamma * reconstruction_loss + (1 - self.gamma) * kullback_leibler_loss)
        # Additionally add the losses as metrics to track them separately
        self.add_metric(reconstruction_loss, name='reconstruction_loss')  # <- DEPRECATED IN KERAS 3
        self.add_metric(kullback_leibler_loss, name='kl_loss')  # <- DEPRECATED IN KERAS 3
        return reconstructed
This no longer works in Keras 3 because `add_metric` is deprecated.
I am aware of the Keras example implementing a VAE. The approach there is to define additional trackers for the losses and to override `train_step` to update them, roughly like the sketch below. Messing with the training logic only to implement additional metrics seems like overkill to me, though, and I was hoping for a cleaner approach in the style of `add_metric`. I guess what I am asking here is: can we have `add_metric` back?
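For reference, here is a condensed sketch of that tracker/`train_step` approach, adapted from the keras.io VAE example to my gamma-weighted loss (TensorFlow backend assumed because of `GradientTape`):

import tensorflow as tf
import keras
from keras import ops

class VariationalAutoEncoder(keras.Model):
    def __init__(self, encoder, decoder, gamma=0.999, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.gamma = gamma
        # One tracker per quantity that should show up in the logs
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            mean, log_var, z = self.encoder(data)
            reconstructed = self.decoder(z)
            reconstruction_loss = ops.mean(ops.square(data - reconstructed))
            kl_loss = -0.5 * ops.mean(log_var - ops.square(mean) - ops.exp(log_var) + 1)
            total_loss = self.gamma * reconstruction_loss + (1 - self.gamma) * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Manually update every tracker so the values appear in the progress bar / history
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}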
This can be done using the new structured loss feature:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
import keras
from keras import layers
from keras import ops
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = keras.random.normal(
            shape=(batch, dim), seed=self.seed_generator
        )
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
latent_dim = 2
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(
    encoder_inputs,
    {"z_mean": z_mean, "z_log_var": z_log_var, "z": z},
    name="encoder",
)
encoder.summary()
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
vae = keras.Model(
    encoder.input,
    {"encoder": encoder.output, "decoder": decoder(encoder.output["z"])},
    name="vae",
)

def encoder_loss_fn(y_true, y_pred):
    z_mean, z_log_var, z = y_pred["z_mean"], y_pred["z_log_var"], y_pred["z"]
    kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
    kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
    return kl_loss
gamma = 0.5
vae.compile(
    optimizer="adam",
    loss={"encoder": encoder_loss_fn, "decoder": "binary_crossentropy"},
    loss_weights={"encoder": 1 - gamma, "decoder": gamma},
)
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
labels = {
    "encoder": keras.tree.map_structure(
        # dummy data, not used in the loss
        lambda _: np.ones((mnist_digits.shape[0], latent_dim)),
        encoder.output,
    ),
    "decoder": mnist_digits,
}
vae.fit(x=mnist_digits, y=labels, epochs=30, batch_size=128)
Epoch 1/30
547/547 ━━━━━━━━━━━━━━━━━━━━ 16s 27ms/step - decoder_loss: 0.1729 - encoder_loss: 2.5852e-05 - loss: 0.1729
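With this, the KL term is logged as encoder_loss and the reconstruction term as decoder_loss, so both are tracked separately without touching the training loop. After training, the pieces can be used as usual, e.g. (a small usage sketch; z_sample is just an arbitrary latent point, and this assumes predict preserves the dict output structure defined above):

# Decode an arbitrary point from the 2-D latent space into an image
z_sample = np.array([[0.0, 0.0]], dtype="float32")
generated_digit = decoder.predict(z_sample)  # shape (1, 28, 28, 1)

# The full VAE returns the structured outputs defined above
outputs = vae.predict(mnist_digits[:1])
reconstruction = outputs["decoder"]  # reconstructed digit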
You merged a solution two days before I posted this issue; this must be the predictive maintenance I've heard so much about ;-) Thank you for the quick response, I'll take a look.