A bug causing gradients to become NaN in the WGAN-GP example
Hi, I want to report a bug and a corresponding fix to your team. I believe this fix could help others who use the WGAN-GP algorithm to train GANs.
Bug and Reason
It appears in the code examples section of the Keras documentation, at https://keras.io/examples/generative/wgan_gp/. Here is the core code, which calculates the gradient penalty (the "GP" in WGAN-GP):
# 3. Calculate the norm of the gradients.
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
During back-propagation, tf.sqrt is non-differentiable at 0. If the gradients of the last batch are all zero, exactly this case occurs: the forward value of norm is still fine (it is 0), but the gradient flowing back through tf.sqrt is Inf.
Why is tf.sqrt non-differentiable at 0?
Because for $y=\sqrt{x}$ the derivative is $y' = \frac{1}{2}x^{-1/2}$: $x$ appears in the denominator, so it cannot be 0.
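The same thing happens with the exact expression from the example: when grads is an all-zero tensor, the forward value of norm is 0, but the backward pass multiplies the Inf derivative of tf.sqrt by the zero gradient of tf.square, and Inf * 0 gives NaN. A minimal check (the shape [2, 4, 4, 3] is only an assumption for illustration):
import tensorflow as tf

# Pretend the gradients w.r.t. the interpolated images came back all zero.
grads = tf.zeros([2, 4, 4, 3])
with tf.GradientTape() as tape:
    tape.watch(grads)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
print(norm)                        # forward value is fine: [0. 0.]
print(tape.gradient(norm, grads))  # backward value is NaN everywhere (Inf * 0)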
Consequence
The Inf gradient coming out of norm then flows into the gradient of gp, which is calculated by the code below.
gp = tf.reduce_mean((norm - 1.0) ** 2)
Thus the gradient of d_loss is corrupted as well. d_loss is calculated by the code below.
# Add the gradient penalty to the original discriminator loss
d_loss = d_cost + gp * self.gp_weight
Finally, d_gradient becomes NaN, because the chain rule multiplies the Inf factor from tf.sqrt by the all-zero gradients, and Inf * 0 = NaN. d_gradient is calculated by the code below, and every subsequent training step then fails.
# Get the gradients w.r.t the discriminator loss
d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
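Here is a sketch of the whole chain in one place. It is not the full WGAN class from the example; it uses a toy discriminator whose kernel is zero-initialised so that its gradient w.r.t. the interpolated input is exactly zero (the shapes, the Dense stand-in, and the omitted d_cost are assumptions for illustration):
import tensorflow as tf

# Toy "discriminator": a single zero-initialised Dense layer, so its gradient
# w.r.t. its input is exactly zero, mimicking the all-zero-gradient batch.
discriminator = tf.keras.layers.Dense(1, use_bias=False, kernel_initializer="zeros")
interpolated = tf.random.normal([4, 8])   # assumed toy shape: batch of 4, 8 features
gp_weight = 10.0

with tf.GradientTape() as tape:
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = discriminator(interpolated)
    grads = gp_tape.gradient(pred, interpolated)     # all zeros in this toy setup
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    d_loss = gp * gp_weight                          # d_cost omitted for brevity
d_gradient = tape.gradient(d_loss, discriminator.trainable_variables)
print(d_gradient)  # NaN, so the optimizer would corrupt every weight it touches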
Fix
Change
# 3. Calculate the norm of the gradients.
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
to
# 3. Calculate the norm of the gradients.
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
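With the epsilon inside the square root, the same all-zero check from above now back-propagates a finite value (reusing the assumed shape):
grads = tf.zeros([2, 4, 4, 3])  # same assumed all-zero case as above
with tf.GradientTape() as tape:
    tape.watch(grads)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
print(tape.gradient(norm, grads))  # finite (all zeros), no Inf or NaN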
MWE 1:
import tensorflow as tf

x = tf.Variable(0.)           # x is zero
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.sqrt(x)            # y is the square root of x
dy_dx = g.gradient(y, x)      # compute the gradient of y with respect to x
print(dy_dx)                  # results in inf
MWE 2:
x = tf.Variable(0.)           # x is zero
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.sqrt(x) + 1e-12    # add delta outside the sqrt
dy_dx = g.gradient(y, x)      # compute the gradient of y with respect to x
print(dy_dx)                  # still inf: a delta outside the sqrt does not change the derivative at 0
MWE 3:
x = tf.Variable(0.)           # x is zero
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.sqrt(x + 0.1)      # add delta inside the sqrt, so it is evaluated away from 0
dy_dx = g.gradient(y, x)      # compute the gradient of y with respect to x
print(dy_dx)                  # this works (finite); if delta is too small the gradient can still blow up, hence 0.1 here