PyTorch-StudioGAN

Gradient penalty "interpolates" term

Open · rl-max opened this issue 2 years ago · 1 comment

Hi,

This is the 'cal_grad_penalty' function in src/utils/losses.py:

def cal_grad_penalty(real_images, real_labels, fake_images, discriminator, device):
    # Draw one interpolation coefficient per sample and broadcast it to the image shape.
    batch_size, c, h, w = real_images.shape
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_images.nelement() // batch_size).contiguous().view(batch_size, c, h, w)
    alpha = alpha.to(device)

    # Random points on the line segments between real and fake images (WGAN-GP style).
    real_images = real_images.to(device)
    interpolates = alpha * real_images + ((1 - alpha) * fake_images)
    interpolates = interpolates.to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    fake_dict = discriminator(interpolates, real_labels, eval=False)

    # Gradient of the discriminator's adversarial output w.r.t. the interpolated images.
    grads = cal_deriv(inputs=interpolates, outputs=fake_dict["adv_output"], device=device)
    grads = grads.view(grads.size(0), -1)

    # Penalize the deviation of each sample's gradient norm from 1.
    grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:,0,0,0].mean()*0
    return grad_penalty
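
For reference, cal_deriv is defined elsewhere in src/utils/losses.py; presumably it is a thin wrapper around torch.autograd.grad, roughly like this (my sketch, not the exact StudioGAN code):

import torch
from torch import autograd

def cal_deriv(inputs, outputs, device):
    # Differentiate the discriminator output w.r.t. the interpolates.
    # create_graph=True keeps the returned gradient itself differentiable, so
    # the penalty on its norm can be backpropagated through the discriminator.
    grads = autograd.grad(outputs=outputs, inputs=inputs,
                          grad_outputs=torch.ones_like(outputs, device=device),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]
    return grads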

In the last line, grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:,0,0,0].mean()*0, I would like to know what the additive term + interpolates[:,0,0,0].mean()*0 means. Since it is multiplied by zero, I think it has no actual effect on the result.

I'll be waiting for your answer.

Thank you!

rl-max · Nov 03 '22 00:11

The term has no effect on the computed loss value; it was introduced to work around a PyTorch bug that arises during DDP training with R1 regularization.
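
To make the workaround concrete, here is a minimal sketch (my own illustration, not StudioGAN code). A term multiplied by zero adds nothing to the loss value, but autograd still records the dependency, so backward() traverses that branch of the graph and fires the gradient hooks that DDP relies on for synchronization:

import torch

x = torch.randn(4, 3, 8, 8, requires_grad=True)

# A loss with no dependence on x at all.
loss = torch.zeros(())

# The zero-multiplied term leaves the value unchanged but attaches x (and
# everything that produced it) to the autograd graph.
loss = loss + x[:, 0, 0, 0].mean() * 0

loss.backward()
print(x.grad.abs().sum())  # tensor(0.): gradients flow, but they are all zero

Similar zero-multiplied terms show up in other GAN codebases for the same reason: they keep every needed branch of the graph reachable so DDP does not treat its gradients as missing.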

mingukkang · Apr 18 '23 03:04