esrgan-tf2

When updating the gradients, is there no need to freeze one of the networks?

Open cl886699 opened this issue 4 years ago • 2 comments

When the gradients are updated, the code does not freeze either G or D:

```python
@tf.function
def train_step(lr, hr):
    # persistent=True lets tape.gradient() be called twice (once per network).
    with tf.GradientTape(persistent=True) as tape:
        sr = generator(lr, training=True)
        hr_output = discriminator(hr, training=True)
        sr_output = discriminator(sr, training=True)

        losses_G = {}
        losses_D = {}
        losses_G['reg'] = tf.reduce_sum(generator.losses)
        losses_D['reg'] = tf.reduce_sum(discriminator.losses)
        losses_G['pixel'] = cfg['w_pixel'] * pixel_loss_fn(hr, sr)
        losses_G['feature'] = cfg['w_feature'] * fea_loss_fn(hr, sr)
        losses_G['gan'] = cfg['w_gan'] * gen_loss_fn(hr_output, sr_output)
        losses_D['gan'] = dis_loss_fn(hr_output, sr_output)
        total_loss_G = tf.add_n([l for l in losses_G.values()])
        total_loss_D = tf.add_n([l for l in losses_D.values()])
    # Each gradient is taken only with respect to one network's variables.
    grads_G = tape.gradient(
        total_loss_G, generator.trainable_variables)
    grads_D = tape.gradient(
        total_loss_D, discriminator.trainable_variables)
    # Each optimizer updates only the variables it is given.
    optimizer_G.apply_gradients(
        zip(grads_G, generator.trainable_variables))
    optimizer_D.apply_gradients(
        zip(grads_D, discriminator.trainable_variables))

    return total_loss_G, total_loss_D, losses_G, losses_D
```

cl886699 · Oct 30 '20 09:10

Does this mean the loss will always be NaN with this code?

taoyu17 · Oct 30 '20 22:10

> Does this mean the loss will always be NaN with this code?

No, it can still converge.

cl886699 · Oct 30 '20 23:10
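One likely reason the step in the issue can train without explicit freezing: `tape.gradient(loss, variables)` returns gradients only for the variables it is given, and `apply_gradients` updates only the variables it is handed. The sketch below tries to illustrate this with tiny stand-in models. The names `toy_G`, `toy_D`, `opt_G`, `opt_D`, and the simple losses are assumptions made for illustration and are not taken from esrgan-tf2.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical stand-ins for the real generator and discriminator.
toy_G = tf.keras.Sequential([tf.keras.layers.Dense(4)])
toy_D = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt_G = tf.keras.optimizers.SGD(learning_rate=0.1)
opt_D = tf.keras.optimizers.SGD(learning_rate=0.1)

lr_batch = tf.random.normal((8, 4))  # plays the role of the low-res input
hr_batch = tf.random.normal((8, 4))  # plays the role of the high-res target

# Call each model once so its variables are created before snapshotting.
toy_D(toy_G(lr_batch))


def snapshot(model):
    """Copy the current values of a model's trainable variables."""
    return [v.numpy().copy() for v in model.trainable_variables]


def same(a, b):
    """True if two snapshots are element-wise identical."""
    return all(np.array_equal(x, y) for x, y in zip(a, b))


# Same structure as the issue: one forward pass, one persistent tape.
with tf.GradientTape(persistent=True) as tape:
    sr = toy_G(lr_batch, training=True)
    hr_out = toy_D(hr_batch, training=True)
    sr_out = toy_D(sr, training=True)
    loss_D = tf.reduce_mean(sr_out) - tf.reduce_mean(hr_out)  # toy critic loss
    loss_G = -tf.reduce_mean(sr_out)                          # toy generator loss

# Both losses depend on both networks, but each gradient list covers
# only the variables requested here.
grads_G = tape.gradient(loss_G, toy_G.trainable_variables)
grads_D = tape.gradient(loss_D, toy_D.trainable_variables)
del tape  # release the resources held by the persistent tape

g0, d0 = snapshot(toy_G), snapshot(toy_D)

# The discriminator step leaves the generator untouched.
opt_D.apply_gradients(zip(grads_D, toy_D.trainable_variables))
g_untouched_by_D_step = same(g0, snapshot(toy_G))
d1 = snapshot(toy_D)

# The generator step leaves the discriminator untouched.
opt_G.apply_gradients(zip(grads_G, toy_G.trainable_variables))
d_untouched_by_G_step = same(d1, snapshot(toy_D))

g_changed = not same(g0, snapshot(toy_G))
d_changed = not same(d0, d1)
```

Because both gradient lists come from the same forward pass, the generator update in this pattern uses the discriminator as it was before its own update. Alternating schemes that freeze one network at a time are a different design choice, not a requirement for the gradients to be computed correctly.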