
Progressive Growing of GANs in PyTorch 1.1

Progressive Growing of GANs

A PyTorch implementation of the paper "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (official TensorFlow code).
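
The core idea of the paper is to start training at a low resolution (e.g. 4x4) and to progressively add layers that handle higher resolutions, smoothly fading each new block in. Below is a minimal sketch of the generator-side fade-in, with illustrative module names (old_to_rgb, new_block and new_to_rgb are not this repository's API):

import torch.nn.functional as F

def faded_generator_output(x, old_to_rgb, new_block, new_to_rgb, alpha):
    # x: feature map produced by the already-trained layers.
    # alpha ramps linearly from 0 to 1 during the transition phase,
    # blending the upsampled old RGB output with the new block's output.
    old_rgb = F.interpolate(old_to_rgb(x), scale_factor=2, mode="nearest")
    new_rgb = new_to_rgb(new_block(x))
    return (1 - alpha) * old_rgb + alpha * new_rgb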

Current progress

Results from a model with at most 256 filters per convolution, trained at 128x128 resolution:

Features

  • Fully implemented progressive growing of GANs to reproduce the results on the CelebA-HQ dataset
  • WGAN-GP loss (see the gradient penalty sketch after this list)
  • Easy-to-use config files to change hyperparameters for experiments
  • Supports both CelebA-HQ and MNIST
  • High-performance data pre-processing pipeline
  • DataParallel support to run on multi-GPU (single-node) systems
  • Mixed precision support with Apex AMP; we recommend optimization level O1 (see the sketch after this list)
  • Loading and saving of checkpoints to stop and resume training (a sketch follows the training steps below)
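
The WGAN-GP loss adds a penalty on the discriminator's gradient norm, evaluated at random interpolations between real and generated images. A minimal sketch of the penalty term (the names and the weight of 10 follow the WGAN-GP paper; this is not necessarily how the repository structures it):

import torch

def gradient_penalty(discriminator, real, fake, weight=10.0):
    # Sample points on straight lines between real and fake images.
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads, = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    # Penalize deviation of the per-sample gradient norm from 1.
    norm = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return weight * ((norm - 1) ** 2).mean()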
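
With Apex AMP, opting into mixed precision mostly amounts to wrapping the model and optimizer once and scaling the loss before backward. A sketch with stand-in modules (the actual training loop lives in src/train.py):

import torch
from apex import amp

model = torch.nn.Linear(512, 1).cuda()   # stand-in for the actual networks
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# O1 patches whitelisted ops to run in FP16 while keeping the rest in FP32.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(8, 512, device="cuda")).mean()  # stand-in loss

# Scale the loss so small FP16 gradients do not underflow.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()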

Requirements

  • PyTorch >= 1.0
  • Apex AMP
  • The packages listed in the requirements file

To reproduce our results, we recommend using the Docker environment defined in the Dockerfile.

Training CelebA-HQ

  1. Download the CelebA-HQ dataset. We recommend using the following code: https://github.com/nperraud/download-celebA-HQ

  2. Pre-process the dataset and save the images at several resolutions (by default, from 4x4 up to 1024x1024):

python3 src/data_tools/generate_datasets.py --source_path data/celebA-HQ --target_path /path/to/extracted/celebA-HQ
  3. Define the hyperparameters. The default config file uses the hyperparameters from the paper, with a maximum of 256 filters in each convolution.

  4. Start training:

python3 src/train.py models/default/config.yml
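
Checkpoints let you stop and resume training (see Features). The repository's exact checkpoint format is defined in the code; the sketch below shows the standard PyTorch pattern with illustrative keys and paths:

import torch

generator = torch.nn.Linear(512, 3)      # stand-ins for the actual networks
discriminator = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(generator.parameters())
global_step = 0

# Save everything needed to resume: weights, optimizer state, progress.
torch.save({
    "generator": generator.state_dict(),
    "discriminator": discriminator.state_dict(),
    "optimizer": optimizer.state_dict(),
    "global_step": global_step,
}, "latest_checkpoint.ckpt")

# Restore the same state to continue training where it stopped.
ckpt = torch.load("latest_checkpoint.ckpt")
generator.load_state_dict(ckpt["generator"])
discriminator.load_state_dict(ckpt["discriminator"])
optimizer.load_state_dict(ckpt["optimizer"])
global_step = ckpt["global_step"]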

TODO

  • WGAN-GP loss + AC-GAN (as presented in the paper) for class conditional datasets
  • CIFAR-10 & LSUN datasets

Reference implementation