CycleGAN
Holes showing up while training
Hi! I've been trying to replicate your paper by building a PyTorch model from scratch and training it on the original vangogh2photo dataset provided by Berkeley. Admittedly, it's for fun rather than research, but I still hate it when it doesn't work out. This is the architecture of the model I've made:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

'''Image history buffer that stores up to 50 previously generated images.
This follows the paper that introduced CycleGAN ("Unpaired Image-to-Image
Translation using Cycle-Consistent Adversarial Networks"), which trains the
discriminators on a history of generated images rather than only the latest ones.'''
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Buffer size must be positive."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.detach():  # detach so stored images carry no autograd graph
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the image and return it as-is.
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    # With probability 0.5, return an older image and replace
                    # it in the buffer with the newly generated one.
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    # Otherwise, return the newly generated image.
                    to_return.append(element)
        return torch.cat(to_return)  # Variable() wrapping is a no-op since PyTorch 0.4
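# Usage sketch (hedged; the generator/discriminator/criterion names below are
# placeholders, not from this post): the buffer sits between the generator
# output and the discriminator update, so D sees a mix of fresh and
# historical fakes.
#
#   fake_B_buffer = ReplayBuffer()
#   fake_B = G_AB(real_A)
#   fake_B_history = fake_B_buffer.push_and_pop(fake_B)
#   loss_D_fake = criterion_GAN(D_B(fake_B_history), fake_label)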
'''Linear learning-rate decay schedule: the multiplier stays at 1.0 until
decay_start_epoch, then decays linearly to 0.0 at n_epochs.'''
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        # <= 0 (not < 0) also guards against a zero division in step().
        if (n_epochs - decay_start_epoch) <= 0:
            raise ValueError(
                "Decay must start before training ends. Set decay_start_epoch "
                "to a value less than {}.".format(n_epochs)
            )
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
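# Wiring sketch (hedged; optimizer_G is a placeholder and the values come
# from the config shown further below): step() matches the lr_lambda
# signature expected by torch.optim.lr_scheduler.LambdaLR.
#
#   scheduler_G = torch.optim.lr_scheduler.LambdaLR(
#       optimizer_G,
#       lr_lambda=LambdaLR(n_epochs=20, offset=0, decay_start_epoch=19).step)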
'''Single residual block. InstanceNorm2d is known to produce blob artefacts;
consider switching to modulated convolutions later. Currently using
augmentation and a low number of epochs to stop the generator from producing artefacts.'''
class ResNetBlock(nn.Module):
    def __init__(self, channels):
        super(ResNetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)
class GeneratorResNet(nn.Module):
    def __init__(self, input_channels, output_channels, num_resnet_blocks=9):
        super(GeneratorResNet, self).__init__()
        # Initial convolutional layer
        self.initial_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        # Downsampling layers
        self.downsampling_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.downsampling_2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )
        # Residual layers
        self.residual_layers = nn.Sequential(
            *[ResNetBlock(256) for _ in range(num_resnet_blocks)]
        )
        # Upsampling layers
        self.upsampling_1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.upsampling_2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        # Final convolutional layer
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, kernel_size=7, padding=0, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        # Apply initial convolutional layer
        x = self.initial_conv(x)
        # Apply downsampling layers
        x = self.downsampling_1(x)
        x = self.downsampling_2(x)
        # Apply residual layers
        x = self.residual_layers(x)
        # Apply upsampling layers
        x = self.upsampling_1(x)
        x = self.upsampling_2(x)
        # Apply final convolutional layer
        x = self.final_conv(x)
        return x
'''PatchGAN discriminator'''
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        # Output shape of the PatchGAN (four stride-2 convs halve H and W four times)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_channels, out_channels, normalize=True):
            """Returns the downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # C64 -> C128 -> C256 -> C512
        self.model = nn.Sequential(
            *discriminator_block(channels, out_channels=64, normalize=False),
            *discriminator_block(64, out_channels=128),
            *discriminator_block(128, out_channels=256),
            *discriminator_block(256, out_channels=512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, padding=1)
        )

    def forward(self, img):
        return self.model(img)
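For reference, a quick shape sanity check (a minimal sketch, assuming the Discriminator class above and a 256x256 RGB input as in my config below) confirms the 16x16 PatchGAN output:

import torch

disc = Discriminator(input_shape=(3, 256, 256))
patch = disc(torch.randn(1, 3, 256, 256))
print(patch.shape)        # torch.Size([1, 1, 16, 16])
print(disc.output_shape)  # (1, 16, 16)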
Now, around the third epoch of training, I'm getting these "holes" in the generated pictures. Could anyone tell me why they show up and how I can prevent them? These are my hyperparameters:
{'name': 'CycleGan_VanGogh_Checkpoint', 'n_epochs': 20, 'batch_size': 4, 'lr': 0.0002, 'decay_start_epoch': 19, 'b1': 0.5, 'b2': 0.999, 'img_size': 256, 'channels': 3, 'num_residual_blocks': 9, 'lambda_cyc': 10.0, 'lambda_id': 5.0}
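For context, this is roughly how those weights enter my generator objective (a minimal sketch; G_AB, G_BA, D_A, D_B, and the tensor arguments are placeholders, not my exact training-loop code):

import torch

criterion_GAN = torch.nn.MSELoss()       # LSGAN adversarial loss
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

def generator_loss(G_AB, G_BA, D_A, D_B, real_A, real_B, valid):
    fake_B = G_AB(real_A)
    fake_A = G_BA(real_B)
    # Adversarial: fool both discriminators
    loss_GAN = criterion_GAN(D_B(fake_B), valid) + criterion_GAN(D_A(fake_A), valid)
    # Cycle consistency: A -> B -> A and B -> A -> B should reconstruct the input
    loss_cycle = criterion_cycle(G_BA(fake_B), real_A) + criterion_cycle(G_AB(fake_A), real_B)
    # Identity: a generator fed an image from its target domain should change nothing
    loss_id = criterion_identity(G_AB(real_B), real_B) + criterion_identity(G_BA(real_A), real_A)
    # lambda_cyc = 10.0, lambda_id = 5.0 from the config above
    return loss_GAN + 10.0 * loss_cycle + 5.0 * loss_id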
Hi! I've been plagued by the same problem for a long time. Have you solved it?
Unfortunately not. I am still getting the same problem.
Sorry to hear that. After my comment on this issue, I found two issues that are highly related to our problem, and the author of CycleGAN has answered them. He thinks these artifacts are caused by mode collapse, and that more training data, larger weights on the identity/cycle-consistency losses, or a smaller learning rate can solve it. I'm trying these now; maybe you can try them too (I've sketched the direction after the links below). Good luck to us both!
BTW, here are the issues I mentioned:
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/725
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/446
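In case it helps, this is the direction I'm trying, purely as a sketch; these values are my own guesses, not numbers from the CycleGAN author:

# Hypothetical tweaks following the mode-collapse advice above: strengthen
# the cycle/identity constraints and slow the optimizer down.
config = {
    'lr': 0.0001,        # halved from 0.0002
    'lambda_cyc': 20.0,  # doubled from 10.0
    'lambda_id': 10.0,   # doubled from 5.0
}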
Thank you! That helps a lot!