TensorFlow-2.x-Tutorials
Two issues with the WGAN-GP source code
While reading through the source code I found two small issues.
1. wgan_train.py still applies a sigmoid followed by a cross-entropy loss, but WGAN should use the Discriminator's raw output logits directly as the loss:
```python
def d_loss_fn(generator, discriminator, batch_z, real_image):
    fake_image = generator(batch_z, training=True)
    d_fake_score = discriminator(fake_image, training=True)
    d_real_score = discriminator(real_image, training=True)
    # Wasserstein critic loss: push real scores up and fake scores down
    loss = tf.reduce_mean(d_fake_score - d_real_score)
    # gradient penalty weight, lambda = 10
    gp = gradient_penalty(discriminator, real_image, fake_image) * 10.
    loss = loss + gp
    return loss, gp


def g_loss_fn(generator, discriminator, batch_z):
    fake_image = generator(batch_z, training=True)
    d_fake_logits = discriminator(fake_image, training=True)
    # loss = celoss_ones(d_fake_logits)  # original sigmoid cross-entropy version
    # use the raw critic output directly instead
    loss = -tf.reduce_mean(d_fake_logits)
    return loss
```
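For reference, these two functions implement the standard WGAN-GP objective from the original paper, with λ = 10 as in the code above:

$$
L_D = \mathbb{E}_{\tilde{x}\sim P_g}[D(\tilde{x})] - \mathbb{E}_{x\sim P_r}[D(x)] + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2\big],\qquad L_G = -\mathbb{E}_{\tilde{x}\sim P_g}[D(\tilde{x})]
$$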
2. After switching to the logits-based loss required by WGAN, I found that training no longer converged. After repeated checking, the problem turned out to be in the gradient penalty computation; after changing the original function to the following, training converges nicely:
```python
def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # one random interpolation coefficient per sample
    # (a dtype mismatch here may have caused the divergence?)
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    grads = tape.gradient(Dx, x_hat)
    # L2 norm of the critic's gradient w.r.t. each interpolated image
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp
```
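To show the fix in context, here is a minimal sketch of how these functions could be wired into one training step. The Adam settings and the single critic/generator update per step are my assumptions, not taken from wgan_train.py:

```python
import tensorflow as tf

# assumed optimizer settings; wgan_train.py may use different values
d_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

def train_step(generator, discriminator, batch_z, real_image):
    # critic update: minimize E[D(fake)] - E[D(real)] + 10 * gp
    with tf.GradientTape() as tape:
        d_loss, gp = d_loss_fn(generator, discriminator, batch_z, real_image)
    d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    # generator update: minimize -E[D(fake)]
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator, discriminator, batch_z)
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))

    return d_loss, g_loss, gp
```

In practice the critic is often updated several times (e.g. 5) per generator step, which should stabilize WGAN training further.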
Before the fix: at around 50,000 epochs the gradients explode and the generator can only produce noise.
After the fix: the training stability of WGAN shows through; I have trained for 160,000 epochs so far and the output is still improving steadily.
Other improvements: with the deconvolution-based generator, if you enlarge the output and look closely you can see faint checkerboard artifacts, probably caused by the overlapping strides of Conv2DTranspose. Replacing the transposed convolutions in the generator with an upsampling + Conv2D structure should remove them, as sketched below. I am still training this change, so the actual effect remains to be confirmed.
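A sketch of the upsampling + Conv2D block I have in mind; filter counts, kernel sizes, and activations here are illustrative, not the tutorial's actual generator layers:

```python
from tensorflow.keras import layers

def upsample_block(x, filters):
    # nearest-neighbour upsampling followed by a plain convolution avoids the
    # uneven kernel overlap of a strided Conv2DTranspose that causes checkerboard artifacts
    x = layers.UpSampling2D(size=2, interpolation='nearest')(x)
    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
    x = layers.BatchNormalization()(x)
    return layers.LeakyReLU(0.2)(x)

# instead of:
# x = layers.Conv2DTranspose(filters, kernel_size=4, strides=2, padding='same')(x)
```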