
SRGAN doesn't work without pretraining a SRResNet (How to fix)

neilthefrobot opened this issue 3 years ago · 5 comments

The GAN losses don't stabilize unless you first pretrain the generator. The network can still improve image quality without pretraining, but only via the VGG loss; the GAN part ends up basically useless (at least in my experience). The fix is extremely easy. In `get_gan_network`, change the compilation to:

```python
gan = Model(gan_input, x)
gan.compile(loss='mse', optimizer=optimizer)
```

and in the training loop just comment out the training of the discriminator. This trains the generator to minimize the MSE between the training inputs and the training targets, with no GAN. Once this model is trained, revert these changes and continue training (using VGG perceptual loss + GAN loss, and training the discriminator in the training loop).

neilthefrobot avatar Mar 04 '21 16:03 neilthefrobot

Do you have the code for this? I am currently stuck trying to implement it. It seems to be working, but when I predict with the generator I only get a black image (all zeros), even after denormalizing. I also tried using the MSE loss like you have above, but the loss output is negative, which should not be the case. Your help is greatly appreciated!!!

HeeebsInc avatar Mar 07 '21 22:03 HeeebsInc

All you need to change is `get_gan_network()` to be:

```python
def get_gan_network(discriminator, shape, generator, optimizer):
    gan_input = Input(shape=shape)
    x = generator(gan_input)
    gan = Model(inputs=gan_input, outputs=x)
    # Plain MSE pretraining -- the discriminator is not used here.
    gan.compile(loss='mse', optimizer=optimizer)
    return gan
```

And then remove all of this (the training of the discriminator in the training loop):

```python
rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

image_batch_hr = x_train_hr[rand_nums]
image_batch_lr = x_train_lr[rand_nums]
generated_images_sr = generator.predict(image_batch_lr)

# Soft labels: reals near 1, fakes near 0
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
fake_data_Y = np.random.random_sample(batch_size) * 0.2

discriminator.trainable = True

d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
```

And set the line that trains the generator to just try to turn LR into HR, without any GAN loss:

```python
gan_loss = gan.train_on_batch(image_batch_lr, image_batch_hr)
```
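Putting the changes together, the pretraining phase reduces to plain supervised MSE regression from LR to HR. A minimal self-contained sketch (the tiny upsampling CNN below is a stand-in for the repo's SRResNet generator, and the toy arrays stand in for the real dataset; batch sampling mirrors the loop above):

```python
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D
from tensorflow.keras.models import Model

# Stand-in generator: a tiny 4x upsampling CNN (the real repo uses SRResNet).
def build_generator(lr_shape=(8, 8, 3)):
    inp = Input(shape=lr_shape)
    x = Conv2D(16, 3, padding='same', activation='relu')(inp)
    x = UpSampling2D(4)(x)  # 4x super-resolution factor
    out = Conv2D(3, 3, padding='same', activation='tanh')(x)
    return Model(inp, out)

generator = build_generator()
# Pretraining: compile with plain MSE -- no discriminator involved at all.
generator.compile(loss='mse', optimizer='adam')

# Toy data in [-1, 1], matching a tanh output generator.
x_train_lr = np.random.uniform(-1, 1, (16, 8, 8, 3)).astype('float32')
x_train_hr = np.random.uniform(-1, 1, (16, 32, 32, 3)).astype('float32')

batch_size = 4
for _ in range(3):  # a real pretraining run needs many more iterations
    rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
    image_batch_hr = x_train_hr[rand_nums]
    image_batch_lr = x_train_lr[rand_nums]
    # Supervised LR -> HR regression; the discriminator is never updated.
    loss = generator.train_on_batch(image_batch_lr, image_batch_hr)
```

Because the loss is pure MSE, it can never go negative; a negative loss (as reported above) is a sign that something else in the pipeline is wrong.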

neilthefrobot avatar Mar 08 '21 15:03 neilthefrobot

thank you so much!!! You are a life saver

HeeebsInc avatar Mar 08 '21 16:03 HeeebsInc

> All you need to change is `get_gan_network()` to be ... (full reply quoted above)

Thanks for the code. So you are basically training a non-GAN this way? Once this part is trained, do you save and re-load the weights and then continue training the GAN model?

siweic0818 avatar Mar 11 '21 08:03 siweic0818

> Thanks for the code. So you are basically training a non-GAN this way? Once this part is trained, do you save and re-load the weights and then continue training the GAN model?

Yes. Fast.ai calls this "NoGAN training". More info: https://www.fast.ai/2019/05/03/decrappify/ You are basically just training a GAN, but you pretrain the generator so that it is already making decent images from the start. It usually keeps things more stable.

neilthefrobot avatar Mar 11 '21 23:03 neilthefrobot