PyTorch-StudioGAN
Gradient penalty "interpolates" term
Hi,
This is the `cal_grad_penalty` function in /src/utils/losses.py:
```python
import torch
from torch import autograd


def cal_grad_penalty(real_images, real_labels, fake_images, discriminator, device):
    batch_size, c, h, w = real_images.shape
    # One interpolation coefficient per sample, broadcast to the full image shape.
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_images.nelement() // batch_size).contiguous().view(batch_size, c, h, w)
    alpha = alpha.to(device)

    real_images = real_images.to(device)
    # Random points on the straight lines between real and fake images.
    interpolates = alpha * real_images + ((1 - alpha) * fake_images)
    interpolates = interpolates.to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    fake_dict = discriminator(interpolates, real_labels, eval=False)
    grads = cal_deriv(inputs=interpolates, outputs=fake_dict["adv_output"], device=device)
    grads = grads.view(grads.size(0), -1)

    grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:,0,0,0].mean()*0
    return grad_penalty
```
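(The snippet calls a helper `cal_deriv` that is not shown above. For reference, here is a minimal sketch of what such a helper typically looks like, assuming the standard `torch.autograd.grad` pattern used for gradient penalties; this is an illustration, not the verbatim StudioGAN implementation.)

```python
import torch
from torch import autograd

def cal_deriv(inputs, outputs, device):
    # Differentiate the discriminator output w.r.t. the interpolated images.
    # create_graph=True keeps the graph so the penalty itself can be
    # backpropagated through during the discriminator update.
    grads = autograd.grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=torch.ones(outputs.size(), device=device),
        create_graph=True,
        retain_graph=True,
    )[0]
    return grads
```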
In the last line, `grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:,0,0,0].mean()*0`, I wanted to know what the additive term `+ interpolates[:,0,0,0].mean()*0` means. Since it is multiplied by zero, I think it actually has no effect on the computation.
I'll be waiting for your answer. Thank you!
You are right that the term has no effect on the computed value. It was introduced to work around a PyTorch bug that arises during DDP training with R1 regularization (a zero-multiplied dummy term like this is a common trick for keeping a tensor attached to the loss in the autograd graph, so that DDP's gradient synchronization does not complain about outputs that would otherwise be unused).
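Both halves of that answer can be checked with a quick standalone sketch (toy tensors, not StudioGAN code): the zero-multiplied term leaves the loss value bit-for-bit identical, while still adding an edge from the loss directly to the tensor in the autograd graph:

```python
import torch

x = torch.randn(4, 3, 8, 8, requires_grad=True)

# Stand-in for the real penalty: some differentiable function of x.
penalty = ((x.flatten(1).norm(2, dim=1) - 1) ** 2).mean()
penalty_with_dummy = penalty + x[:, 0, 0, 0].mean() * 0

# The loss value is unchanged: the dummy term contributes exactly zero.
assert torch.equal(penalty, penalty_with_dummy)

# But the dummy term adds an extra node connecting the loss directly to x,
# which is the graph property the DDP workaround relies on.
print(penalty.grad_fn)             # MeanBackward0
print(penalty_with_dummy.grad_fn)  # AddBackward0

penalty_with_dummy.backward()
print(x.grad.shape)                # torch.Size([4, 3, 8, 8])
```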