Keras-SRGAN
SRGAN doesn't work without pretraining a SRResNet (How to fix)
The GAN losses don't stabilize unless you first pretrain the generator. If you don't, the network will still end up improving image quality, but only because of the VGG loss; the GAN part will be basically useless (at least in my experience).
The fix is extremely easy. In `get_gan_network` change the compilation to be

```python
gan = Model(gan_input, x)
gan.compile(loss='mse', optimizer=optimizer)
```
and in the training loop just comment out the training of the discriminator. This trains the generator to minimize the MSE between the training inputs and the training targets, with no GAN. Once this model is trained, remove these changes and continue training (using the VGG perceptual loss + GAN loss and training the discriminator in the training loop).
Do you have the code for this? I am currently stuck trying to implement this. It seems as though it is working, but when I predict on the generator I only get a black image (all zeros), even after I denormalize. I also tried using the MSE loss like you have above, but the loss output is negative, which should not be the case. Your help is greatly appreciated!!!
All you need to change is `get_gan_network()` to be:

```python
def get_gan_network(discriminator, shape, generator, optimizer):
    gan_input = Input(shape=shape)
    x = generator(gan_input)
    gan = Model(inputs=gan_input, outputs=x)
    gan.compile(loss='mse', optimizer=optimizer)
    return gan
```
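For reference, you call it the same way as before; the LR input shape and optimizer settings below are just assumed example values, not the repo's exact ones:

```python
from keras.optimizers import Adam

shape = (64, 64, 3)              # low-resolution input shape (assumption)
adam = Adam(lr=1e-4, beta_1=0.9) # example optimizer settings (assumption)
gan = get_gan_network(discriminator, shape, generator, adam)
```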
And then remove all of this (the training of the discriminator in the training loop):

```python
rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
image_batch_hr = x_train_hr[rand_nums]
image_batch_lr = x_train_lr[rand_nums]
generated_images_sr = generator.predict(image_batch_lr)
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2
discriminator.trainable = True
d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
```
And set the line that trains the generator to just try to turn LR into HR, without any GAN loss:

```python
gan_loss = gan.train_on_batch(image_batch_lr, image_batch_hr)
```
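Putting it together, a minimal pretraining loop might look something like this. This is a rough sketch; the batch size, epoch count, and weight filename are assumptions, not the exact repo code:

```python
import numpy as np

batch_size = 16          # assumption
pretrain_epochs = 100    # assumption
batch_count = x_train_hr.shape[0] // batch_size

for epoch in range(pretrain_epochs):
    for _ in range(batch_count):
        # sample a random batch of LR/HR pairs
        rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
        image_batch_hr = x_train_hr[rand_nums]
        image_batch_lr = x_train_lr[rand_nums]

        # no discriminator step here -- just pixel-wise MSE from LR to HR
        gan_loss = gan.train_on_batch(image_batch_lr, image_batch_hr)

# save the pretrained generator so the full SRGAN run can start from it
generator.save_weights('pretrained_generator.h5')  # filename is an assumption
```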
thank you so much!!! You are a life saver
Thanks for the code. So you are basically training a non-GAN this way? Once this part is trained, do you save and re-load the weights and then continue training the GAN model?
Yes. Fast.ai calls this "NoGAN training". More info: https://www.fast.ai/2019/05/03/decrappify/ You are basically just training a GAN, but you pretrain the generator so that it is already making decent images from the start. It usually keeps things more stable.
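The save/re-load step between the two phases could look roughly like this. It is only a sketch: the weight filename is an assumption, and it assumes you restore `get_gan_network()` to its original VGG + adversarial loss version for phase two:

```python
# Phase 1: MSE-only pretraining (modified get_gan_network, discriminator step removed)
# ... pretraining loop from above ...
generator.save_weights('pretrained_generator.h5')   # filename is an assumption

# Phase 2: restore the original SRGAN setup, then warm-start the generator
generator.load_weights('pretrained_generator.h5')
gan = get_gan_network(discriminator, shape, generator, adam)  # original VGG + GAN version
# ... continue with the normal SRGAN training loop, including the discriminator step ...
```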