esrgan-tf2
esrgan-tf2 copied to clipboard
When updating the gradients, is it not necessary to freeze one of the networks?
when update the gradient , did not freeze g or d in code ` @tf.function def train_step(lr, hr): with tf.GradientTape(persistent=True) as tape: sr = generator(lr, training=True) hr_output = discriminator(hr, training=True) sr_output = discriminator(sr, training=True)
losses_G = {}
losses_D = {}
losses_G['reg'] = tf.reduce_sum(generator.losses)
losses_D['reg'] = tf.reduce_sum(discriminator.losses)
losses_G['pixel'] = cfg['w_pixel'] * pixel_loss_fn(hr, sr)
losses_G['feature'] = cfg['w_feature'] * fea_loss_fn(hr, sr)
losses_G['gan'] = cfg['w_gan'] * gen_loss_fn(hr_output, sr_output)
losses_D['gan'] = dis_loss_fn(hr_output, sr_output)
total_loss_G = tf.add_n([l for l in losses_G.values()])
total_loss_D = tf.add_n([l for l in losses_D.values()])
grads_G = tape.gradient(
total_loss_G, generator.trainable_variables)
grads_D = tape.gradient(
total_loss_D, discriminator.trainable_variables)
optimizer_G.apply_gradients(
zip(grads_G, generator.trainable_variables))
optimizer_D.apply_gradients(
zip(grads_D, discriminator.trainable_variables))
return total_loss_G, total_loss_D, losses_G, losses_D
`
Does this mean the loss will always be NaN with this code?
No — it can still converge: `tape.gradient` is taken with respect to each network's own trainable variables, so the other network is effectively frozen during each update.