CycleGAN

Holes showing up while training

Open bear96 opened this issue 1 year ago • 4 comments

Hi, so I've been trying to replicate your paper by creating a PyTorch model from scratch and training it on the original vangogh2photo dataset provided by Berkeley. Admittedly, it's for fun and not for any research, but I still hate it when it doesn't work out. So, this is the architecture of the model I've made:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random

class ReplayBuffer:
    """Fixed-size history of previously generated images (default capacity 50).

    Implements the image buffer described in the CycleGAN paper ("Unpaired
    Image-to-Image Translation using Cycle-Consistent Adversarial Networks"):
    ``push_and_pop`` returns a batch that mixes freshly generated images with
    older ones held in the buffer (presumably to be fed to the discriminators).

    Args:
        max_size: Maximum number of images kept in the buffer; must be > 0.

    Raises:
        ValueError: If ``max_size`` is not positive.
    """

    def __init__(self, max_size=50):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}.")
        self.max_size = max_size
        self.data = []  # stored images, each with a leading batch dim of 1

    def push_and_pop(self, data):
        """Store the images in ``data`` and return a batch of the same size.

        Each image along the first dimension of ``data`` is handled on its own:

        * While the buffer is not yet full, the image is stored and returned.
        * Once the buffer is full, with probability 0.5 a randomly chosen
          stored image is returned (as a clone) and replaced by the new image.
          Otherwise, the new image is returned and the buffer is unchanged.

        Args:
            data: Batch tensor whose first dimension indexes images.

        Returns:
            A tensor with the same shape as ``data``, carrying no autograd
            history.
        """
        to_return = []
        # Detach once so stored and returned images carry no autograd history.
        # This replaces the deprecated ``.data`` / ``Variable`` idiom.
        for element in data.detach():
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: keep the new image and return it.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Return a random OLDER image and store the new one in its slot.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                # Return the new image; the buffer is left unchanged.
                to_return.append(element)
        return torch.cat(to_return)
    

class LambdaLR:
    """Linear learning-rate decay schedule.

    The factor returned by :meth:`step` is 1.0 until ``epoch + offset``
    reaches ``decay_start_epoch``. It then falls linearly to 0.0 at
    ``n_epochs`` and stays at 0.0 afterwards. ``step`` has the
    ``epoch -> factor`` signature expected by the ``lr_lambda`` argument of
    ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        n_epochs: Epoch at which the factor reaches 0.0.
        offset: Added to ``epoch`` before comparing against
            ``decay_start_epoch``.
        decay_start_epoch: Epoch at which linear decay begins; must be
            strictly less than ``n_epochs``.

    Raises:
        ValueError: If ``decay_start_epoch >= n_epochs``. Equality used to be
            accepted and made every ``step`` call divide by zero.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        if n_epochs <= decay_start_epoch:
            raise ValueError("Decay should start before training ends. Change decay_start_epoch to a value less than {}.".format(n_epochs))
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the LR multiplier in [0.0, 1.0] for ``epoch``."""
        elapsed = max(0, epoch + self.offset - self.decay_start_epoch)
        factor = 1.0 - elapsed / (self.n_epochs - self.decay_start_epoch)
        # Past n_epochs the linear formula would turn negative; hold at 0.
        return max(0.0, factor)
    
class ResNetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip.

    ``conv_block`` is pad -> conv -> InstanceNorm -> ReLU -> pad -> conv ->
    InstanceNorm, and ``forward`` returns ``x + conv_block(x)``. Channel count
    and spatial size are therefore preserved.

    Author's note: InstanceNorm2d appears to produce blob artefacts; modulated
    convolutions were considered as a later replacement. For now, augmentation
    and a low number of epochs are used to keep the generator from producing
    artefacts.
    """

    def __init__(self, channels):
        super().__init__()

        def padded_conv():
            # Reflection padding of 1 keeps H x W unchanged for a 3x3 kernel.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(channels),
            ]

        layers = padded_conv() + [nn.ReLU(inplace=True)] + padded_conv()
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class GeneratorResNet(nn.Module):
    """ResNet-style CycleGAN generator.

    Pipeline: 7x7 conv (64 ch) -> two stride-2 downsampling convs (128, 256 ch)
    -> ``num_resnet_blocks`` residual blocks at 256 ch -> two stride-2
    transposed convs (128, 64 ch) -> 7x7 conv to ``output_channels`` -> tanh.
    Outputs are bounded to [-1, 1] by the final tanh.

    The output spatial size equals the input size only when H and W are
    divisible by 4. Each downsampling step yields ceil(size / 2), and each
    upsampling step exactly doubles the size.

    Args:
        input_channels: Number of channels in the input image.
        output_channels: Number of channels in the generated image.
        num_resnet_blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, input_channels, output_channels, num_resnet_blocks=9):
        super(GeneratorResNet, self).__init__()

        # Initial 7x7 convolution; reflection padding of 3 keeps H x W unchanged.
        self.initial_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Downsampling layers (kernel 3, stride 2, padding 1: size -> ceil(size / 2)).
        self.downsampling_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.downsampling_2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Residual layers at the bottleneck resolution.
        self.residual_layers = nn.Sequential(
            *[ResNetBlock(256) for _ in range(num_resnet_blocks)]
        )

        # Upsampling layers (output_padding=1 makes each step exactly double H and W).
        self.upsampling_1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.upsampling_2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Final 7x7 convolution back to image channels, mirroring initial_conv.
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, kernel_size=7, padding=0, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        # Apply initial convolutional layer
        x = self.initial_conv(x)

        # Apply downsampling layers
        x = self.downsampling_1(x)
        x = self.downsampling_2(x)

        # Apply residual layers
        x = self.residual_layers(x)

        # Apply upsampling layers
        x = self.upsampling_1(x)
        x = self.upsampling_2(x)

        # Apply final convolutional layer
        x = self.final_conv(x)

        return x

    
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    The network has four stride-2 4x4 conv stages (64 -> 128 -> 256 -> 512
    channels). The first stage has no InstanceNorm, and every stage ends in
    LeakyReLU(0.2). These are followed by asymmetric zero padding and a final
    4x4 conv that maps to a single-channel grid of patch scores.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.

    Attributes:
        output_shape: ``(1, height // 16, width // 16)``, the per-image shape
            of the score map.
    """

    def __init__(self, input_shape):
        super().__init__()

        channels, height, width = input_shape
        # Each 4x4 / stride-2 / padding-1 stage maps a size s to floor(s / 2).
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        widths = [channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # The first stage is deliberately left unnormalized.
            if stage > 0:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad one pixel on the left and top so the final 4x4 conv keeps the size.
        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, padding=1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        scores = self.model(img)
        return scores

Now around the third epoch of training, I am getting these "holes" in the generated pictures. Could anyone tell me why these are showing up and how I can prevent it? These are my hyperparameters:

{'name': 'CycleGan_VanGogh_Checkpoint', 'n_epochs': 20, 'batch_size': 4, 'lr': 0.0002, 'decay_start_epoch': 19, 'b1': 0.5, 'b2': 0.999, 'img_size': 256, 'channels': 3, 'num_residual_blocks': 9, 'lambda_cyc': 10.0, 'lambda_id': 5.0}

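![outputs-2](https://user-images.githubusercontent.com/73417041/231626093-fc0ce59a-7e10-42b0-b38c-ab672f70b0bf.png)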

bear96 avatar Apr 13 '23 01:04 bear96

Hi! I've been plagued by the same problem for a long time. Have you solved it?

Joechann0831 avatar Jun 13 '23 06:06 Joechann0831

Unfortunately not. I am still getting the same problem.

bear96 avatar Jun 13 '23 13:06 bear96

> Unfortunately not. I am still getting the same problem.

Sorry to hear that. After my comment on this issue, I found two issues closely related to our problem, and the author of CycleGAN has answered both. He attributes these artifacts to mode collapse and suggests that more training data, larger weights for the identity/cycle-consistency losses, or a smaller learning rate can solve the problem. I'm trying these now; maybe you can try them too. Good luck to us!

BTW, here are the issues I mentioned:

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/725

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/446

Joechann0831 avatar Jun 13 '23 13:06 Joechann0831

Thank you! That helps a lot!

bear96 avatar Jun 13 '23 13:06 bear96