wgan-gp
bug in calc_gradient_penalty?
https://github.com/caogang/wgan-gp/blob/ae47a185ed2e938c39cf3eb2f06b32dc1b6a2064/gan_language.py#L159
Shouldn't "gradients = gradients.view(BATCH_SIZE, -1)" be added right before this line?
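A minimal sketch of the proposed fix, assuming a PyTorch version that provides torch.ones_like and nn.Flatten. It mirrors the shape of the repo's calc_gradient_penalty but is not copied from it. lambda_gp is a hypothetical stand-in for the repo's LAMBDA, and the batch size is read from real_data instead of the BATCH_SIZE constant. It uses reshape, which behaves like the proposed view(BATCH_SIZE, -1) but also tolerates non-contiguous tensors.

```python
import torch
from torch import autograd, nn


def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=10.0):
    # Sketch only: lambda_gp is a hypothetical stand-in for the repo's LAMBDA,
    # and the batch size comes from the data rather than BATCH_SIZE.
    batch_size = real_data.size(0)

    # One interpolation coefficient per sample, broadcast over all other axes,
    # so the same code handles rank-2, rank-3 and rank-4 inputs.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)

    # Detached inputs make interpolates a leaf we can mark as requiring grad.
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
        retain_graph=True,
    )[0]

    # The proposed fix: collapse every non-batch axis so that norm(dim=1)
    # yields exactly one L2 norm per sample.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
```

With a purely linear discriminator, every sample's input gradient equals the weight vector, so the penalty should come out to lambda_gp * (||W|| - 1)^2 regardless of the input rank, which makes a convenient sanity check.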
It's OK in gan_mnist.py because real_data and fake_data have shape [BATCH_SIZE, OUTPUT_DIM], and Discriminator reshapes its input at the start of forward.
https://github.com/caogang/wgan-gp/blob/ae47a185ed2e938c39cf3eb2f06b32dc1b6a2064/gan_mnist.py#L100
But in gan_cifar10.py, Discriminator requires its input tensor to have rank 4, which means gradients also has rank 4.
https://github.com/caogang/wgan-gp/blob/ae47a185ed2e938c39cf3eb2f06b32dc1b6a2064/gan_cifar10.py#L89-L93
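So the gradient_penalty computation is not correct: gradients.norm(2, dim=1) on a rank-4 tensor reduces only the channel axis instead of each sample's whole gradient.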
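A small shape check illustrating the problem, using a random tensor with a hypothetical CIFAR-10-like shape [64, 3, 32, 32] in place of real gradients:

```python
import torch

# Hypothetical CIFAR-10-like gradient tensor: [batch, channels, height, width].
gradients = torch.randn(64, 3, 32, 32)

# Norm along dim=1 on a rank-4 tensor reduces only the channel axis,
# leaving one value per (sample, pixel) instead of one per sample.
per_pixel = gradients.norm(2, dim=1)
print(per_pixel.shape)  # torch.Size([64, 32, 32])

# Flattening first reduces over all 3*32*32 input dimensions of each sample.
per_sample = gradients.view(gradients.size(0), -1).norm(2, dim=1)
print(per_sample.shape)  # torch.Size([64])
```

Penalizing per_pixel pushes each pixel's channel-gradient norm toward 1; if all 32*32 of them hit exactly 1, the full per-sample norm would be sqrt(1024) = 32 rather than the intended 1.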
In the original(?) TensorFlow implementation, Discriminator takes a rank-2 tensor, so a norm along axis 1 is already one norm per sample.
https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py#L71
In gan_language.py, the original Discriminator takes a rank-3 tensor, but the norm is computed along both non-batch axes, so it is still one norm per sample.
https://github.com/igul222/improved_wgan_training/blob/master/gan_language.py#L107
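A sketch, under assumed shapes, of why the view(BATCH_SIZE, -1) fix matches that behaviour: for a rank-3 tensor, summing squares over both non-batch axes gives the same per-sample norm as flattening first. The shape [64, 32, 100] is a hypothetical [batch, seq_len, charmap_size], and random data stands in for real gradients.

```python
import torch

# Hypothetical gan_language.py-like shape: [batch, seq_len, charmap_size].
gradients = torch.randn(64, 32, 100)

# Mirror of the TensorFlow approach: sum squares over both non-batch axes.
slopes_two_axes = torch.sqrt((gradients ** 2).sum(dim=(1, 2)))

# Proposed PyTorch fix: flatten every non-batch axis, then norm along dim 1.
slopes_flat = gradients.view(gradients.size(0), -1).norm(2, dim=1)

# Both give one norm per sample and agree up to float rounding.
assert slopes_two_axes.shape == slopes_flat.shape == (64,)
assert torch.allclose(slopes_two_axes, slopes_flat, rtol=1e-4)
```

The flatten-then-norm form has the advantage of not depending on the input rank, so the same line works for the MNIST, CIFAR-10 and language models.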