
How to improve quality

Open vibe007 opened this issue 3 years ago • 23 comments

Thanks for putting this together! I'm having some success using this to create stylized artwork. I'm wondering what the avenues are to improve quality? It sounds like training for longer helps, along with adding attention. Is there a --network-capacity flag similar to your stylegan2 project? Should increasing the maximum number of feature maps (fmap_max) help? What about increasing the size of the latent_dim?

If we scale up to multi-GPU should we scale the learning rate a corresponding amount?

vibe007 avatar Nov 23 '20 17:11 vibe007

@vibe007 the author reported that changing the discriminator output size helped for artworks specifically (https://github.com/lucidrains/lightweight-gan#discriminator-output-size); other than that, you'll have to reach for the state-of-the-art solution https://github.com/lucidrains/stylegan2-pytorch
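
for example (the data path is a placeholder; --disc-output-size 5 is the value used for artwork later in this thread):

lightweight_gan --data ./artwork --image-size 256 --disc-output-size 5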

lucidrains avatar Nov 23 '20 17:11 lucidrains

@vibe007 what size are you training at and for how long? you should keep training until it collapses

lucidrains avatar Nov 23 '20 17:11 lucidrains

So far I've tested resolutions 128 and 256, training for 150k iterations (the default). I'll try training for longer since it definitely hasn't collapsed yet. The SOTA project you mentioned takes much, much longer to train in my testing.

I am using the recommended discriminator-output-size for artwork. Thanks!

vibe007 avatar Nov 23 '20 18:11 vibe007

@vibe007 yup, you can raise the number of training steps with --num-train-steps
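
for example, to double the default of 150k (the data path is a placeholder):

lightweight_gan --data ./images --num-train-steps 300000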

lucidrains avatar Nov 23 '20 18:11 lucidrains

@vibe007 with GANs, training isn't over if the game hasn't collapsed!

lucidrains avatar Nov 23 '20 18:11 lucidrains

Thanks! Hope to update with some cool images soon...

vibe007 avatar Nov 23 '20 21:11 vibe007

Hey, great tips! I have some similar questions. Currently, I'm giving the model some very cartoon-styled artwork. I'm running it in Colab with a Tesla V100-SXM2-16GB GPU. I haven't let any run go to the full 150k (usually you've released a new version by then and I'm just too curious to see if it offers better results), but even so, by 60k it's getting an FID of ~100 and looks like it's properly converging. Which is fantastic to me! But still, I wonder if there's something I could do to make it even more awesome. Here are the settings I'm using; see anything glaring that I should change? Thanks in advance!

lightweight_gan --data /content/images --sle_spatial --attn-res-layers [32,64] --amp --disc-output-size 5 --models_dir "path" --results_dir "path" --calculate_fid_every 10000

I'm pretty sure I'm leaving processing power on the table here with my Colab notebook. P.S. Thanks for the auto aug_prob work, I always doubted I was using the right odds there 😊

druidOfCode avatar Nov 25 '20 03:11 druidOfCode

@bckwalton good to hear! keep training! you can extend training by increasing --num-train-steps. I also added truncation to the input normal latents, which should help generation quality a bit. Do share a sample if you are allowed to do so :)
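
the rough idea behind the truncation, as a generic sketch (not necessarily the exact code in the repo; the latent size of 256 below is just an assumed value):

import torch

def truncated_normal(shape, trunc = 1.5):
    # sample N(0, 1) and resample any entries whose magnitude exceeds `trunc`;
    # a lower `trunc` trades sample diversity for fidelity
    z = torch.randn(shape)
    while (mask := z.abs() > trunc).any():
        z = torch.where(mask, torch.randn_like(z), z)
    return z

latents = truncated_normal((8, 256))  # e.g. a batch of 8 latents of assumed size 256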

lucidrains avatar Nov 25 '20 21:11 lucidrains

Awesome! Right now everything looks a bit like soup since the last posting 🤔, so I may have to change some settings. I'll report back with better results ⚡. All the images in the dataset are cartoon characters in portrait shots, and you can kinda see that here. Rather soupy (this is step 69,000, generated with 0.12.2, the truncation version).

druidOfCode avatar Nov 25 '20 22:11 druidOfCode

In my results (20k iterations), I see mesh-like artifacts that are very noticeable, far more than in some of the demo samples. Is this a bug, or an inherent flaw of this kind of GAN that can't really be avoided? If it's the latter, is there a way to compensate for them?

tannisroot avatar Nov 26 '20 01:11 tannisroot

@tannisroot you can try an antialiased version of this GAN by using the --antialias flag at the start of training. it'll be a bit slower though
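
for context, antialiasing in convnets usually means low-pass filtering feature maps before downsampling (in the spirit of "Making Convolutional Networks Shift-Invariant Again"). a minimal sketch of such a blur-pool, not the repo's exact code:

import torch
import torch.nn.functional as F
from torch import nn

class BlurPool2d(nn.Module):
    # blur with a fixed 3x3 binomial kernel, then downsample by 2; the low-pass
    # filter suppresses the aliasing that plain strided downsampling introduces
    def __init__(self, channels):
        super().__init__()
        k = torch.tensor([1., 2., 1.])
        kernel = k[:, None] * k[None, :]
        kernel = kernel / kernel.sum()
        self.register_buffer('kernel', kernel[None, None].repeat(channels, 1, 1, 1))
        self.channels = channels

    def forward(self, x):
        x = F.pad(x, (1, 1, 1, 1), mode = 'reflect')
        return F.conv2d(x, self.kernel, stride = 2, groups = self.channels)

x = torch.randn(1, 64, 32, 32)
print(BlurPool2d(64)(x).shape)  # torch.Size([1, 64, 16, 16])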

lucidrains avatar Nov 26 '20 03:11 lucidrains

@tannisroot otherwise, just train for longer, 20k is still quite early!

lucidrains avatar Nov 26 '20 03:11 lucidrains

Back with results. After toying with some additional attention layers, here are the settings I landed on for Colab. (A warning to anyone following along for Colab settings: check which GPU you were given before using these. Colab likes to alternate between P100 and V100, and these settings are for a V100; if you get a P100 you won't have enough VRAM to cram this many attention layers in.)

lightweight_gan --data dataset_path --sle_spatial --attn-res-layers [32,64,128] --image-size 512 --amp --disc-output-size 5 --models_dir models_path --results_dir results_path --calculate_fid_every 10000

I also sanity-checked my dataset; most of it is from random boorus, filtered by tags. I tried my best to clean out as many mistags as I could and culled about 300 of 5,148 images (some just outright incorrect, others whole comic pages that just happened to have a portrait shot in one of the frames).

The results are markedly improved. It's only at 52,000 steps right now (35%), but it's actually "understandable", and the FID is reporting 108. Since I changed parts of the dataset and added to the model, I can't be certain what's responsible here, but here are my results regardless, in case there's something to learn from them. 😀

  • 52,000 steps: "It's converging already" (Imgur gif of latent space exploration)
  • 70,000 steps: "Making progress" (Imgur gif of latent space exploration)
  • 150,000 steps: "Needs more time" (Imgur gif of latent space exploration)

druidOfCode avatar Nov 26 '20 18:11 druidOfCode

I've noticed that this GAN does not have the issue where the augmentations leak into the results, which is very common for https://github.com/lucidrains/stylegan2-pytorch. Is it known why it doesn't suffer from this? If so, can the state-of-the-art implementation be improved to avoid the problem? Can the auto-augmentation be backported as well? Very nifty feature!

tannisroot avatar Nov 28 '20 18:11 tannisroot

See these two papers on how to augment without leaking the augmentations into the generator:

Karras, Tero, et al. "Training Generative Adversarial Networks with Limited Data." Advances in Neural Information Processing Systems 33 (2020).

Zhao, Shengyu, et al. "Differentiable Augmentation for Data-Efficient GAN Training." Advances in Neural Information Processing Systems 33 (2020).

I think the main idea is we want to augment the discriminator inputs as opposed to the generator inputs, and it appears the augmentations in https://github.com/lucidrains/stylegan2-pytorch do this correctly.

You can grab this file to use these augmentations in another project: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/diff_augment.py
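
Roughly, the pattern looks like this (a simplified sketch with an illustrative augmentation; the linked diff_augment.py has the real set):

import torch

def rand_brightness(x):
    # a differentiable brightness jitter, one random offset per image in the batch
    return x + (torch.rand(x.size(0), 1, 1, 1, device = x.device) - 0.5)

def diff_augment(images, fns = (rand_brightness,)):
    for fn in fns:
        images = fn(images)
    return images

# In the training loop, augment BOTH real and generated images right before
# the discriminator (the `discriminator`, `generator`, `real_images` and `latents`
# names below are placeholders). Because the augmentations are differentiable,
# gradients still flow back to the generator through the fake branch, and the
# generator itself never has to reproduce the augmentations:
#
#   d_real = discriminator(diff_augment(real_images))
#   d_fake = discriminator(diff_augment(generator(latents)))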

vibe007 avatar Nov 28 '20 18:11 vibe007

It's actually the opposite: lightweight-gan's augmentation doesn't leak, but stylegan2-pytorch's does. Thanks for the hint, though, I'll try swapping in the augmenting bits!

tannisroot avatar Nov 28 '20 19:11 tannisroot

The augmentation in the papers mentioned above should not leak. If you experience leaks with the implementation, maybe there is an issue with the implementation or the parameters.

See: https://github.com/lucidrains/stylegan2-pytorch#low-amounts-of-training-data

If one augments at a low enough probability, the augmentations should not 'leak' into the generations.
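
For example, something along these lines (flag names as in the stylegan2-pytorch README; check the current README in case they have changed):

stylegan2_pytorch --data ./your-images --aug-prob 0.25 --aug-types [translation,cutout]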

woctezuma avatar Nov 28 '20 20:11 woctezuma

If the volume of training data is high, it's possible that data augmentation can hurt image quality (as discussed in "Training Generative Adversarial Networks with Limited Data"). It's also possible that you're not training for long enough; see https://github.com/NVlabs/stylegan2 for expected StyleGAN2 training times (it's days to weeks).

vibe007 avatar Nov 28 '20 20:11 vibe007

This model works surprisingly well even for small (fewer than 10,000 images), complex and highly variable datasets, and it took only about 15 hours on a single GeForce 1070 Ti. Some tips to improve quality for cartoon images (an example command follows the list):

  • Use dual contrast loss (the --dual-contrast-loss flag).
  • Entirely remove the attention blocks (pass --attn-res-layers [], yes, just empty square brackets; without this flag, attention is added to the 32x32 layers). The attention mechanism requires a really HUGE amount of data to work well: if you have fewer than 600K images, attention significantly reduces the visual quality of your images and also makes the model overfit.
  • Replace the GlobalContext block with the original FastGAN implementation of SEBlock. The reason is that GlobalContext contains an attention mechanism, and that's not good for small datasets.
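
A command along these lines would apply the first two tips (assuming the flag names above match your installed version; the third tip requires the code edit shown further down in this thread):

lightweight_gan --data ./cartoons --image-size 256 --dual-contrast-loss --attn-res-layers []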

iScriptLex avatar Dec 05 '21 15:12 iScriptLex

@iScriptLex would you elaborate on that last bullet? How would I replace the GlobalContext block?

rickyars avatar Mar 29 '22 13:03 rickyars

@rickyars You should edit the file lightweight_gan.py. First, add the original FastGAN implementation of SEBlock before the Generator class definition (just paste this code before the class Generator(nn.Module): line):

import torch
from torch import nn  # both are already imported at the top of lightweight_gan.py

class Swish(nn.Module):
    def forward(self, feat):
        # swish activation: x * sigmoid(x)
        return feat * torch.sigmoid(feat)

# SEBlock (skip-layer excitation) as implemented in FastGAN: it pools the incoming
# feature map to 4x4 and squeezes it down to a 1x1 per-channel gate in (0, 1),
# which the generator multiplies into a higher-resolution feature map
class SEBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            nn.Conv2d(ch_in, ch_out, 4, bias=False),
            Swish(),
            nn.Conv2d(ch_out, ch_out, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
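
As a quick sanity check of the shapes (the sizes here are arbitrary, not tied to any particular layer of the model):

se = SEBlock(ch_in = 256, ch_out = 512)
feat_small = torch.randn(1, 256, 16, 16)  # a low-resolution feature map
gate = se(feat_small)                     # shape (1, 512, 1, 1), values in (0, 1)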

Then, in the __init__ method of the Generator class, replace this code:

sle = GlobalContext(
    chan_in = chan_out,
    chan_out = sle_chan_out
)

with this:

sle = SEBlock(ch_in = chan_out, ch_out = sle_chan_out)

iScriptLex avatar Mar 29 '22 19:03 iScriptLex

Thank you, @iScriptLex. I'll give this a try tonight. One more question for you: whenever I run with --sle_spatial, it tells me:

ERROR: Could not consume arg: --sle_spatial

Any idea what I'm doing wrong?

rickyars avatar Apr 02 '22 19:04 rickyars

Any idea what I'm doing wrong?

  • d2ed25b36ae228c24a45083df6a9c04f7baf45fe

woctezuma avatar Apr 02 '22 20:04 woctezuma