STGAN-pytorch

STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing


An unofficial PyTorch implementation of STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing.



From left to right: original image, then edits for Bangs, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, Pale_Skin, Young.


Please download the CelebA dataset from its official project page, then organize the directory as follows:

data_root
├── image
│   ├── 000001.jpg
│   ├── 000002.jpg
│   ├── 000003.jpg
│   └── ...
└── anno
    ├── list_attr_celeba.txt
    └── ...
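The annotation file list_attr_celeba.txt follows the standard CelebA layout: the first line is the image count, the second lists the attribute names, and every following line holds a filename plus one value per attribute (1 for present, -1 for absent). The sketch below is not taken from this repository; it only illustrates one way to read the columns named in the attrs option and map them to 0/1 labels. The function name load_selected_attrs is hypothetical, and the sample uses four attributes instead of the real forty to keep it short.

```python
from typing import Dict, List


def load_selected_attrs(lines: List[str], attrs: List[str]) -> Dict[str, List[int]]:
    """Parse CelebA-style attribute lines and keep only the selected attributes.

    Hypothetical helper: returns {filename: [0/1 label per selected attribute]}.
    """
    # Line 0: number of images; line 1: whitespace-separated attribute names.
    all_names = lines[1].split()
    missing = [a for a in attrs if a not in all_names]
    if missing:
        raise KeyError(f"unknown attributes: {missing}")
    columns = [all_names.index(a) for a in attrs]

    labels: Dict[str, List[int]] = {}
    for line in lines[2:]:
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        filename, values = fields[0], fields[1:]
        # CelebA stores 1 for present and -1 for absent; map to 1 / 0.
        labels[filename] = [1 if values[c] == "1" else 0 for c in columns]
    return labels


if __name__ == "__main__":
    sample = [
        "2",
        "Bangs Eyeglasses Male Young",
        "000001.jpg -1  1 -1  1",
        "000002.jpg  1 -1  1 -1",
    ]
    print(load_selected_attrs(sample, ["Bangs", "Male"]))
```

For the real file, the same function would be called with the lines of data_root/anno/list_attr_celeba.txt and the attrs list from the configuration.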


  • To get started quickly, train with the following command:

    CUDA_VISIBLE_DEVICES=0,1,2,3 python --config ./configs/train_stgan.yaml
  • To modify hyper-parameters, edit the configuration file ./configs/train_stgan.yaml. Each option is explained below:

    • exp_name: the name of the current experiment.
    • mode: 'train' or 'test'.
    • cuda: whether to use CUDA.
    • ngpu: the number of GPUs to use. Note: this must not exceed the number of devices listed in CUDA_VISIBLE_DEVICES.
    • dataset: the name of the dataset. Note: the code can be extended to support other datasets.
    • data_root: the root directory of the dataset.
    • crop_size: the crop size of images.
    • image_size: the size of input images during training.
    • g_conv_dim: the base number of filters in the convolutional layers of G.
    • d_conv_dim: the base number of filters in the convolutional layers of D.
    • d_fc_dim: the dimension of the fully-connected layers in D.
    • g_layers: the number of convolutional layers in G. Note: the encoder and the decoder each have this many layers.
    • d_layers: the number of convolutional layers in D.
    • shortcut_layers: the number of shortcut connections in G. Note: this is also the number of STUs.
    • stu_kernel_size: the kernel size of convolutional layers in STU.
    • use_stu: if set to false, there will be no STU in shortcut connections.
    • one_more_conv: if set to true, an extra convolutional layer is added between the decoder and the generated image.
    • attrs: the list of selected attributes. Note: see list_attr_celeba.txt for all available attributes.
    • checkpoint: the iteration step of the checkpoint to resume from. Note: set this to ~ (null) when training for the first time.
    • batch_size: the batch size of the data loader.
    • beta1: the beta1 value of the Adam optimizer.
    • beta2: the beta2 value of the Adam optimizer.
    • g_lr: the base learning rate of G.
    • d_lr: the base learning rate of D.
    • n_critic: the number of D updates per G update.
    • thres_int: the threshold applied to the target attribute vector during training.
    • lambda_gp: the trade-off coefficient of D_loss_gp.
    • lambda1: the trade-off coefficient of D_loss_att.
    • lambda2: the trade-off coefficient of G_loss_att.
    • lambda3: the trade-off coefficient of G_loss_rec.
    • max_iters: the maximum number of training iterations.
    • lr_decay_iters: the number of iterations between learning rate decays.
    • summary_step: the number of iterations between tensorboardX summaries.
    • sample_step: the number of iterations between sampling operations.
    • checkpoint_step: the number of iterations between checkpoint saves.
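To make the scheduling options concrete, here is a minimal sketch, not code from this repository, of how ngpu, n_critic and the *_step options might interact. It assumes a StarGAN-style loop in which D is updated every iteration, G once every n_critic iterations, and each *_step option fires when the 1-based step is a multiple of its value. EXAMPLE_CONFIG, check_ngpu and actions_at are hypothetical names, and the values are placeholders rather than the defaults in train_stgan.yaml.

```python
from typing import Dict, List

# Hypothetical subset of ./configs/train_stgan.yaml as a plain dict.
# The values are illustrative placeholders, not the repository's defaults.
EXAMPLE_CONFIG: Dict[str, int] = {
    "ngpu": 4,
    "n_critic": 5,
    "summary_step": 10,
    "sample_step": 50,
    "checkpoint_step": 100,
}


def check_ngpu(config: Dict[str, int], visible_devices: str) -> None:
    """Raise ValueError if ngpu exceeds the devices in a CUDA_VISIBLE_DEVICES string."""
    devices = [d for d in visible_devices.split(",") if d.strip()]
    if config["ngpu"] > len(devices):
        raise ValueError(
            f"ngpu={config['ngpu']} but only {len(devices)} device(s) are visible"
        )


def actions_at(config: Dict[str, int], step: int) -> List[str]:
    """Return the operations that would run at a 1-based iteration step.

    Assumption: StarGAN-style scheduling, i.e. D every step, G every
    n_critic steps, and each *_step option on multiples of its value.
    """
    actions = ["update_D"]
    if step % config["n_critic"] == 0:
        actions.append("update_G")
    if step % config["summary_step"] == 0:
        actions.append("summary")
    if step % config["sample_step"] == 0:
        actions.append("sample")
    if step % config["checkpoint_step"] == 0:
        actions.append("save_checkpoint")
    return actions


if __name__ == "__main__":
    check_ngpu(EXAMPLE_CONFIG, "0,1,2,3")  # same device list as the training command
    for step in (1, 5, 50, 100):
        print(step, actions_at(EXAMPLE_CONFIG, step))
```

With these placeholder values, step 1 only updates D, step 5 also updates G, and step 100 triggers every periodic operation. Passing a device string such as "0,1" with ngpu set to 4 raises an error, mirroring the note on ngpu.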


This code is based on the following two projects:

[1] TensorFlow implementation of STGAN

[2] PyTorch implementation of StarGAN