
A bug causing gradient to NaN in WGAN-gp example

Open Sneaker001333 opened this issue 3 years ago • 2 comments

Hi, I want to report a bug and a corresponding fix to your team. I believe this fix could help others who use the WGAN-gp algorithm to train GANs.

Bug and Reason

It appears in the "Code examples" section of the Keras documentation, at https://keras.io/examples/generative/wgan_gp/. Here is the core code, which calculates the gradient penalty (the "gp" in WGAN-gp):

# 3. Calculate the norm of the gradients.
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
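For one sample, write $g_i$ for the entries of its slice of grads and $\varepsilon$ for a small constant added inside the square root (notation introduced here, not taken from the Keras example). Back-propagation through this line applies the chain rule below. When every $g_i$ is 0, floating-point evaluation produces $\infty \cdot 0$, which IEEE 754 arithmetic defines as NaN, while an $\varepsilon$ inside the square root keeps the first factor finite:

$$
\frac{\partial\,\mathrm{norm}}{\partial g_i}
= \frac{1}{2\sqrt{\sum_j g_j^2}} \cdot 2 g_i
\quad\xrightarrow{\;g = 0\;}\quad
\frac{1}{2 \cdot 0} \cdot 0 = \infty \cdot 0 = \mathrm{NaN}
$$

$$
\frac{\partial\,\mathrm{norm}_\varepsilon}{\partial g_i}
= \frac{1}{2\sqrt{\sum_j g_j^2 + \varepsilon}} \cdot 2 g_i
\quad\xrightarrow{\;g = 0\;}\quad
\frac{1}{2\sqrt{\varepsilon}} \cdot 0 = 0
$$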

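During back-propagation, tf.sqrt is not differentiable at 0, and TensorFlow evaluates its gradient there as Inf. This happens when, for some sample in the batch, the gradients are all zero: the sum of squares is then 0, so norm itself is a finite 0, but the gradient flowing back through tf.sqrt is Inf.

Why is tf.sqrt non-differentiable at 0? Because its derivative is $y' = \frac{1}{2}x^{-1/2} = \frac{1}{2\sqrt{x}}$, and x appears in the denominator, so it cannot be 0.

Consequence

The forward pass stays finite: a sample with norm = 0 simply contributes $(0 - 1)^2 = 1$ to gp. The Inf appears only when back-propagating through gp, which is calculated by the code below.

gp = tf.reduce_mean((norm - 1.0) ** 2)

gp is then added to the discriminator loss, so back-propagating d_loss passes through the tf.sqrt above. d_loss is calculated by the code below.

# Add the gradient penalty to the original discriminator loss
d_loss = d_cost + gp * self.gp_weight

Finally, when d_gradient is calculated by the code below, the chain rule multiplies the Inf from tf.sqrt by the gradient of tf.square(grads), which is 2 * grads = 0 for that sample. In floating point, Inf * 0 = NaN, so d_gradient becomes NaN and all following training steps fail.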

# Get the gradients w.r.t the discriminator loss
d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
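To see where the NaN appears, the forward and backward passes for a single sample's penalty can be traced by hand in plain Python. This is a sketch that mimics reverse-mode differentiation, not TensorFlow's code: the helper name is hypothetical, the batch mean is dropped (one sample only), gp_weight = 10.0 is an assumed value, and because Python raises ZeroDivisionError on 0.5 / 0.0, the code substitutes +Inf explicitly to match the Inf TensorFlow returns for the gradient of tf.sqrt at 0.

```python
import math

def penalty_and_grad(g, eps=0.0, gp_weight=10.0):
    """Forward and backward pass for one sample's gradient penalty (hypothetical helper).

    Forward:  s = sum(g_i ** 2) + eps,  norm = sqrt(s),  gp = gp_weight * (norm - 1) ** 2
    Backward: d gp / d g_i via the chain rule, using IEEE float arithmetic.
    """
    s = sum(x * x for x in g) + eps
    norm = math.sqrt(s)
    gp = gp_weight * (norm - 1.0) ** 2

    # d gp / d norm
    d_gp_d_norm = gp_weight * 2.0 * (norm - 1.0)
    # d norm / d s = 0.5 / sqrt(s); use +inf when s == 0 (Python would raise instead)
    d_norm_d_s = 0.5 / norm if norm > 0.0 else math.inf
    # d s / d g_i = 2 * g_i; when norm == 0 and g_i == 0 this evaluates inf * 0 = nan
    grad = [d_gp_d_norm * d_norm_d_s * 2.0 * x for x in g]
    return norm, gp, grad

# All-zero gradients, without and with the epsilon from the fix.
print(penalty_and_grad([0.0, 0.0, 0.0]))
print(penalty_and_grad([0.0, 0.0, 0.0], eps=1e-12))
```

Without the epsilon, norm and gp are finite (0 and gp_weight) but every gradient entry is NaN; with it, the gradient entries are exactly zero.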

Fix

Change

# 3. Calculate the norm of the gradients.
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))

to

# 3. Calculate the norm of the gradients.
norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
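A rough way to check how little the 1e-12 changes the penalty value itself is to recompute the per-sample norms and gp in plain Python for a tiny made-up batch. The function names and the batch are hypothetical, each sample's gradients are assumed to be already flattened (standing in for the reduction over axis=[1, 2, 3]), and Python floats are float64, whereas TensorFlow computes in float32 by default, so exact numbers will differ slightly.

```python
import math

def per_sample_norms(grads, eps=0.0):
    # Mirrors tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + eps),
    # with each sample's gradients already flattened into one list.
    return [math.sqrt(sum(x * x for x in g) + eps) for g in grads]

def penalty_from_grads(grads, eps=0.0):
    # Mirrors gp = tf.reduce_mean((norm - 1.0) ** 2).
    norms = per_sample_norms(grads, eps)
    return sum((n - 1.0) ** 2 for n in norms) / len(norms)

# Hypothetical batch of three samples; the last one has all-zero gradients.
batch = [[0.6, 0.8], [3.0, 4.0], [0.0, 0.0]]
plain = penalty_from_grads(batch)             # original formula
fixed = penalty_from_grads(batch, eps=1e-12)  # formula with the fix
print(plain, fixed, abs(plain - fixed))
```

The all-zero sample's norm moves from 0 to about 1e-6, so its contribution to gp moves from 1 to about (1 - 1e-6)^2; the samples with nonzero gradients are essentially unchanged.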

Sneaker001333, Jun 10 '22 23:06

MWE 1:

import tensorflow as tf

x = tf.Variable(0.)  # x is zero
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.sqrt(x)  # y is the square root of x
dy_dx = g.gradient(y, x)  # gradient of y with respect to x
print(dy_dx)  # results in inf

MWE 2:

x = tf.Variable(0.)  # x is zero
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.sqrt(x) + 1e-12  # add delta to the output of tf.sqrt
dy_dx = g.gradient(y, x)  # gradient of y with respect to x
print(dy_dx)  # still results in inf

MWE 3:

x = tf.Variable(0.)  # x is zero
with tf.GradientTape() as g:
    g.watch(x)
    y = tf.sqrt(x + 0.1)  # add delta to x, then take the square root
dy_dx = g.gradient(y, x)  # gradient of y with respect to x
print(dy_dx)  # this works: the gradient is 1 / (2 * sqrt(0.1)), about 1.58

The gradient at x = 0 is 1 / (2 * sqrt(delta)), so it grows as delta shrinks: a delta that is too small could lead to gradient explosion, hence 0.1 here.
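The effect of the size of delta in MWE 3 can be checked without TensorFlow: the derivative of sqrt(x + delta) at x = 0 is 1 / (2 * sqrt(delta)). The hypothetical helpers below evaluate that formula and compare it with a forward-difference estimate whose step is scaled to delta (the scaling is a choice made for this sketch so the estimate stays accurate for tiny delta).

```python
import math

def sqrt_grad_at_zero(delta):
    # Analytic derivative of sqrt(x + delta) at x = 0.
    return 1.0 / (2.0 * math.sqrt(delta))

def finite_difference(delta, rel_step=1e-6):
    # Forward-difference estimate of the same derivative, with the step h
    # scaled to delta so it stays much smaller than delta.
    h = delta * rel_step
    return (math.sqrt(delta + h) - math.sqrt(delta)) / h

for delta in (0.1, 1e-6, 1e-12):
    print(f"delta={delta:g}  analytic={sqrt_grad_at_zero(delta):g}  "
          f"finite-difference={finite_difference(delta):g}")
```

The scalar gradient grows as delta shrinks (about 5e5 for delta = 1e-12). In the WGAN-gp norm, however, this factor is multiplied by 2 * g_i, and the combined derivative g_i / sqrt(sum_j g_j^2 + eps) always has magnitude at most 1, so the 1e-12 used in the fix does not make the derivative of norm explode.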

adderbyte, Dec 15 '22 18:12