Machine-Learning-Collection
I have a question about the gradient propagation of the Discriminator in WGAN-GP.
In the training loop of WGAN-GP (train.py), the gradient propagation is performed as follows:
```python
fake = gen(noise)                           # generator forward pass
critic_real = critic(real).reshape(-1)      # critic scores on real images
critic_fake = critic(fake).reshape(-1)      # critic scores on generated images
gp = gradient_penalty(critic, real, fake, device=device)
loss_critic = (
    -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
)
critic.zero_grad()
loss_critic.backward(retain_graph=True)     # single backward pass through the combined loss
opt_critic.step()
```
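The gradient_penalty function itself is not shown above. For context, a minimal sketch of the standard WGAN-GP penalty (Gulrajani et al., 2017) that matches this call signature might look like the following; this is my own sketch, not necessarily the repo's exact utils.py implementation:

```python
import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    # Sketch of the standard WGAN-GP penalty (Gulrajani et al., 2017), not the repo's exact code.
    batch_size = real.shape[0]

    # One interpolation coefficient per sample, broadcast over the image dimensions.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = real * alpha + fake * (1 - alpha)
    if not interpolated.requires_grad:
        interpolated.requires_grad_(True)

    # Critic scores on the interpolated images.
    mixed_scores = critic(interpolated)

    # d(scores)/d(interpolated), kept in the graph so the penalty is itself differentiable.
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalize deviation of the per-sample gradient norm from 1.
    gradient = gradient.view(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
```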
You backpropagate through the single combined loss_critic. In other implementations I have looked at, critic_real and critic_fake are backpropagated separately, i.e. with critic_real.backward() and critic_fake.backward(). What is the difference between these two methods, and which one would you prefer?
Example: Zeleni9/pytorch-wgan/models/wgan_gradient_penalty.py - https://github.com/Zeleni9/pytorch-wgan
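For concreteness, the separate-backward pattern I mean looks roughly like this; this is my own sketch in the same context as the snippet above, not the linked repo's exact code (one/mone are the +1/-1 gradient seeds such implementations typically pass to backward()):

```python
# Assumes the same variables as train.py: gen, critic, real, noise, opt_critic, LAMBDA_GP, device.
one = torch.tensor(1.0, device=device)
mone = -one

critic.zero_grad()

fake = gen(noise).detach()                  # detach so the critic step does not touch gen's graph
critic_real = critic(real).reshape(-1).mean()
critic_real.backward(mone)                  # accumulates gradients of -E[critic(real)]

critic_fake = critic(fake).reshape(-1).mean()
critic_fake.backward(one)                   # accumulates gradients of +E[critic(fake)]

gp = gradient_penalty(critic, real, fake, device=device)
(LAMBDA_GP * gp).backward()                 # accumulates gradients of the penalty term

opt_critic.step()                           # gradients simply sum across the three backward() calls
```

As far as I can tell, the gradients accumulate in .grad across the backward() calls, so the resulting critic update should be the same either way; I would like to understand whether there is any practical difference.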