handson-ml2

[QUESTION] Plotting loss with the Deep Convolutional GAN

Reisa14 opened this issue 3 years ago · 0 comments

Describe what is unclear to you

When creating autoencoders, we used fit() to produce a History object, which could then be used to plot the loss across the training and validation periods, as on p. 590:

history = variational_ae.fit(X_train, X_train, epochs=25, batch_size=128, 
                             validation_data=(X_valid, X_valid))
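For comparison, here is a minimal sketch (not from the book) of how that History object is typically turned into a loss plot; `plot_history` is a hypothetical helper name, and the `"loss"`/`"val_loss"` keys are the ones Keras records when a validation set is passed to fit():

```python
import matplotlib.pyplot as plt

def plot_history(history):
    # history.history is a dict mapping metric names to per-epoch values
    plt.plot(history.history["loss"], label="training loss")
    plt.plot(history.history["val_loss"], label="validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
```

Calling `plot_history(history)` after the fit() call above would draw both curves on one axis.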

However, when creating the GANs and deep convolutional GANs, we do not use fit(); instead we use the custom train_gan() function:

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=20):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for X_batch in dataset:
            # phase 1 - training the discriminator
            X_batch = tf.cast(X_batch, tf.float32)
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
        plot_multiple_images(generated_images, 8)
        plt.show()
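One possible approach (a sketch, not the book's code): since the models here are compiled with binary cross-entropy as the only loss, train_on_batch should return the scalar loss for that batch, so the function can collect those values and return them for plotting. The image-plotting lines are omitted for brevity, and the per-batch lists will be noisy, so you may want to average them per epoch:

```python
import matplotlib.pyplot as plt
import tensorflow as tf

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=20):
    generator, discriminator = gan.layers
    d_losses, g_losses = [], []  # one entry per batch
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for X_batch in dataset:
            # phase 1 - training the discriminator
            X_batch = tf.cast(X_batch, tf.float32)
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            d_losses.append(discriminator.train_on_batch(X_fake_and_real, y1))
            # phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            g_losses.append(gan.train_on_batch(noise, y2))
    return d_losses, g_losses

# Afterwards, something like:
# d_losses, g_losses = train_gan(gan, dataset, batch_size, codings_size)
# plt.plot(d_losses, label="discriminator loss")
# plt.plot(g_losses, label="generator loss")
# plt.xlabel("Batch"); plt.ylabel("Loss"); plt.legend(); plt.show()
```

Note that if the models were compiled with extra metrics, train_on_batch would return a list rather than a scalar, in which case you would append only its first element.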

If we wanted to plot the loss of both the discriminator and generator across all epochs in the example on page 599, how would we go about this?

Thanks!

Reisa14 · May 11 '22 10:05