Diffusion-GAN
Use Diffusion-GAN in Other GAN Architectures
Hello @Zhendong-Wang and Team,
First, I would like to say that this is great work! Thank you for sharing the code. I am trying to use Diffusion-GAN in a GAN architecture for image enhancement. Could you please explain how to apply the three Simple Plug-in steps from the README to the training loop below?
for epoch in range(num_epochs):
    for n_batch, (blur_batch, clean_batch) in enumerate(data_loader):
        real_data = clean_batch.float().cuda()
        noised_data = blur_batch.float().cuda()

        # 1. Train Discriminator
        # Generate fake data
        fake_data = generator(noised_data)
        # Reset gradients
        d_optimizer.zero_grad()

        # 1.1 Train on Real Data
        prediction_real = discriminator(real_data, noised_data)
        # Calculate error and backpropagate
        real_data_target = torch.ones_like(prediction_real)
        loss_real = loss1(prediction_real, real_data_target)

        # 1.2 Train on Fake Data, you would need to add one more component
        prediction_fake = discriminator(fake_data, noised_data)
        # Calculate error and backpropagate
        fake_data_target = torch.zeros_like(prediction_real)
        loss_fake = loss1(prediction_fake, fake_data_target)
        loss_d = (loss_real + loss_fake) / 2
        loss_d.backward(retain_graph=True)

        # 1.3 Update weights with gradients
        d_optimizer.step()

        # 2. Train Generator
        g_optimizer.zero_grad()
        # Sample noise and generate fake data
        prediction = discriminator(fake_data, real_data)
        # Calculate error and backpropagate
        real_data_target = torch.ones_like(prediction)
        #import pdb; pdb.set_trace();
        loss_g1 = loss1(prediction, real_data_target)
        loss_g2 = loss1(fake_data, real_data) * 500
        loss_g = loss_g1 + loss_g2
        loss_g.backward()
        # Update weights with gradients
        g_optimizer.step()

        # Log error
        logger.log(loss_d, loss_g, epoch, n_batch, num_batches)

        # Display Progress
        if (n_batch) % 100 == 0:
            display.clear_output(True)
            # Display Images
            test_images = vectors_to_images(generator(test_noise())).data.cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
            # Display status Logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                loss_d, loss_g, prediction_real, prediction_fake
            )

    # Model Checkpoints
    logger.save_models(generator, discriminator, epoch)
Thank you so much :)
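For anyone reading along, here is a rough sketch of what "diffusing real and fake images before the discriminator" could look like for a loop like the one above. It is not the repository's `diffusion.py`: the helper names `make_alpha_bars` and `diffuse`, the linear beta schedule, and the defaults (`num_steps=1000`, `noise_std=0.05`) are all assumptions chosen for illustration.

```python
import torch


def make_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative signal-retention factors for a linear beta schedule.

    Hypothetical helper: the schedule shape and its end points are assumptions.
    """
    betas = torch.linspace(beta_start, beta_end, num_steps)
    return torch.cumprod(1.0 - betas, dim=0)


def diffuse(x, alpha_bars, t_max, noise_std=0.05):
    """Noise each image in the batch x at a random step t in [0, t_max).

    Returns the noised batch and the sampled step indices. The operation is
    differentiable with respect to x, so generator gradients still flow.
    """
    # One timestep per sample in the batch.
    t = torch.randint(0, t_max, (x.shape[0],), device=x.device)
    # Reshape to (B, 1, 1, ...) so it broadcasts over the image dimensions.
    a_bar = alpha_bars.to(x.device)[t].view(-1, *([1] * (x.dim() - 1)))
    noise = torch.randn_like(x) * noise_std
    x_t = a_bar.sqrt() * x + (1.0 - a_bar).sqrt() * noise
    return x_t, t
```

In the loop above, the idea would be to call `diffuse(real_data, alpha_bars, t_max)` and `diffuse(fake_data, alpha_bars, t_max)` and pass the noised tensors to `discriminator` in place of `real_data` and `fake_data`, in both the discriminator and the generator step. Whether the discriminator also receives `t` as an extra input is a design choice this sketch leaves open.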
Excuse me, have you managed to use Diffusion-GAN successfully? If so, could you share your experience?
@RisabBiswas, @someonegirl, were you able to use Diffusion-GAN in other GAN architectures? Could you please share your experience?
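On the adaptive part of the plug-in, here is a minimal sketch of one way to adjust the maximum diffusion step every few batches, based on how confidently the discriminator classifies real images. This is my reading of the general idea, not the repository's code: the function name `update_max_timestep`, the constants (`d_target=0.6`, `step=8`, `t_min=5`, `t_cap=1000`), and the assumption that the discriminator outputs probabilities in [0, 1] are all illustrative.

```python
def update_max_timestep(t_max, real_probs, d_target=0.6, step=8,
                        t_min=5, t_cap=1000):
    """Return an adjusted maximum diffusion step.

    real_probs: discriminator outputs on (noised) real images, assumed to be
    probabilities in [0, 1]. All constants are illustrative assumptions.
    """
    # r_d is the mean of sign(D(real) - 0.5): close to 1 when the
    # discriminator is very confident on real data (possible overfitting).
    signs = [(p > 0.5) - (p < 0.5) for p in real_probs]
    r_d = sum(signs) / len(signs)
    # Raise the noise ceiling when r_d exceeds the target, lower it otherwise.
    direction = (r_d > d_target) - (r_d < d_target)
    return max(t_min, min(t_cap, t_max + direction * step))
```

With the loop from the original post, this could be called every few batches with something like `prediction_real.detach().flatten().tolist()`, and the returned value used as the upper bound on the timestep sampled when noising the discriminator inputs. If the discriminator outputs raw logits instead, the 0.5 threshold would need to become 0.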