PyTorch-GAN
Use of detach
Hello, I'm currently learning about PyTorch and GANs, and I want to ask about these particular lines from the WGAN implementation here.
# Configure input
real_imgs = Variable(imgs.type(Tensor))
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
fake_imgs = generator(z).detach()
# Adversarial loss
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_D.backward()
optimizer_D.step()
Why does fake_imgs have .detach() called on it while real_imgs doesn't?
What would happen if I put .detach() on real_imgs as well? Would that mess with the discriminator update?
Thank you.
It is due to the computational graph. When a fake image is generated, the computational graph contains all the "events" from the latent variables to the final fake image. We don't want to build a computation graph on top of that and then backpropagate through the entire graph (G + D), but only through D. So you detach (or make a separate copy of) the fake image and do the computation on that. The real image is an input; it also doesn't require grad, so it plays no role in backpropagation either way.
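A minimal sketch of this, using hypothetical tiny linear layers as stand-ins for the real generator and discriminator: after calling .detach() on the generator's output, backpropagating the discriminator loss populates gradients only for D, never for G.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real Generator and Discriminator networks.
generator = nn.Linear(4, 8)
discriminator = nn.Linear(8, 1)

z = torch.randn(2, 4)                 # sampled noise batch
fake_imgs = generator(z).detach()     # cut the graph at G's output

loss_D = discriminator(fake_imgs).mean()
loss_D.backward()

print(generator.weight.grad)                   # None: nothing flowed back into G
print(discriminator.weight.grad is not None)   # True: D still gets its gradients
```

Without the .detach(), generator.weight.grad would be populated too, wasting computation during the discriminator step (and leaving stale gradients on G).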
Thank you for your response, I see. So it doesn't matter whether I put .detach() on real_imgs or not in this case, right?
Whether I do real_imgs = Variable(imgs.type(Tensor))
or real_imgs = Variable(imgs.type(Tensor)).detach(),
it will give the same result?
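Right. A quick sketch illustrating why: a batch coming straight from the data loader is a plain tensor with requires_grad=False, so calling .detach() on it is effectively a no-op. (The tensor here is a hypothetical stand-in for a real-image batch.)

```python
import torch

# Stand-in for a batch of real images from the data loader.
imgs = torch.randn(2, 8)

print(imgs.requires_grad)            # False: data-loader tensors don't track grads

detached = imgs.detach()
print(detached.requires_grad)        # False: still no grad tracking
print(torch.equal(imgs, detached))   # True: same values either way
```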
Also, I forgot to specify where the code lines came from: they're from the WGAN implementation, in the discriminator-training part.