
I have a question about the gradient propagation of the Discriminator in WGAN-GP.

TissueNam opened this issue 4 years ago · 0 comments

In the training loop of WGAN-GP (train.py), the critic's gradients are computed as follows:

        fake = gen(noise)
        critic_real = critic(real).reshape(-1)
        critic_fake = critic(fake).reshape(-1)
        gp = gradient_penalty(critic, real, fake, device=device)
        loss_critic = (
            -(torch.mean(critic_real)-torch.mean(critic_fake)) + LAMBDA_GP * gp
        )
        critic.zero_grad()
        loss_critic.backward(retain_graph=True)
        opt_critic.step()

You use the final combined loss_critic for backpropagation. Looking at other people's code, I have also seen critic_real and critic_fake backpropagated separately, via critic_real.backward() and critic_fake.backward(). What is the difference between these two methods, and which one would you prefer?

Example) Zeleni9/pytorch-wgan/models/wgan_gradient_penalty.py — https://github.com/Zeleni9/pytorch-wgan
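For anyone comparing the two styles: because PyTorch's backward() accumulates gradients into .grad, calling backward() once on the combined scalar loss and calling it separately on each term should produce the same critic gradients (ignoring the gradient penalty term, which is omitted here for brevity). A minimal sketch with a toy linear critic (the names real, fake, and critic below are placeholders, not the repo's actual tensors):

```python
import torch

# Toy "critic": a single linear layer, so gradients are easy to compare.
torch.manual_seed(0)
critic = torch.nn.Linear(4, 1)
real = torch.randn(8, 4)
fake = torch.randn(8, 4)

# Method 1 (as in train.py): one combined scalar loss, one backward().
critic.zero_grad()
loss = -(critic(real).mean() - critic(fake).mean())
loss.backward()
grads_combined = [p.grad.clone() for p in critic.parameters()]

# Method 2 (e.g. Zeleni9/pytorch-wgan): separate backward() calls.
# backward() accumulates into .grad, so the terms sum to the same total.
critic.zero_grad()
(-critic(real).mean()).backward()  # gradient of -E[D(real)]
critic(fake).mean().backward()     # gradient of +E[D(fake)]
grads_separate = [p.grad.clone() for p in critic.parameters()]

# Both methods yield identical gradients.
for g1, g2 in zip(grads_combined, grads_separate):
    assert torch.allclose(g1, g2)
```

So mathematically the two styles are equivalent; the choice is mostly stylistic (the separate-backward style comes from older WGAN code that passed explicit `one`/`-one` gradient tensors to backward()).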

TissueNam — Apr 16 '21