inpainting_gmcnn
Image Inpainting via Generative Multi-column Convolutional Neural Networks, NeurIPS 2018
Image Inpainting via Generative Multi-column Convolutional Neural Networks
by Yi Wang, Xin Tao, Xiaojuan Qi, Xiaoyong Shen, Jiaya Jia.
Results on Places2, CelebA-HQ, and Paris streetview with rectangle masks.

Results on Places2 and CelebA-HQ with random strokes.
Introduction
This repository is for the NeurIPS 2018 paper, 'Image Inpainting via Generative Multi-column Convolutional Neural Networks'.
If our method is useful for your research, please consider citing:
@inproceedings{wang2018image,
title={Image Inpainting via Generative Multi-column Convolutional Neural Networks},
author={Wang, Yi and Tao, Xin and Qi, Xiaojuan and Shen, Xiaoyong and Jia, Jiaya},
booktitle={Advances in Neural Information Processing Systems},
pages={331--340},
year={2018}
}
Our framework

Partial Results
More results
Prerequisites
- Python 3.5 (or higher)
- TensorFlow 1.4 (or later versions, excluding 2.x) with NVIDIA GPU or CPU
- OpenCV
- NumPy
- SciPy
- easydict
- PyTorch 1.0 with NVIDIA GPU or CPU
- tensorboardX
Installation
git clone https://github.com/shepnerd/inpainting_gmcnn.git
cd inpainting_gmcnn/tensorflow
or
cd inpainting_gmcnn/pytorch
For the TensorFlow implementation
Testing
Download the pretrained models through the following links (paris_streetview, CelebA-HQ_256, CelebA-HQ_512, Places2), then unzip them and put them into checkpoints/. To test the images in a folder, specify the folder address with the option --dataset_path and set the pretrained model path with --load_model_dir when calling test.py.
For example:
python test.py --dataset paris_streetview --data_file ./imgs/paris-streetview_256x256/ --load_model_dir ./checkpoints/paris-streetview_256x256_rect --random_mask 0
or
sh ./script/test.sh
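Inpainting pipelines commonly feed the network an image with the unknown region blanked out, then keep the known pixels from the input and take only the hole region from the network output. The following is a minimal NumPy sketch of those two steps, assuming HxWxC float images and an HxW mask with 1 for unknown pixels. The helper names are hypothetical and are not functions from test.py.

```python
import numpy as np


def masked_input(image, mask):
    """Zero out the unknown region. image: HxWxC floats; mask: HxW, 1 = unknown (assumed)."""
    m = mask[..., None].astype(image.dtype)
    return image * (1.0 - m)


def composite(image, prediction, mask):
    """Keep known pixels from `image` and take hole pixels from `prediction`."""
    m = mask[..., None].astype(image.dtype)
    return image * (1.0 - m) + prediction * m
```

Compositing guarantees that the known region is passed through unchanged, so any artifacts can only appear inside the hole.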
Training
For a given dataset, training consists of two stages. We first pretrain the whole network with only the confidence-driven reconstruction loss; after this phase converges, we finetune the network using the adversarial and ID-MRF losses along with the reconstruction loss.
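The confidence-driven reconstruction loss weights each missing pixel by how close it is to known pixels, so pixels near the hole boundary count more than pixels deep inside it. A minimal sketch of that idea, following our reading of the paper's formulation: the known region is repeatedly Gaussian-smoothed into the hole to build a weight map, which then scales an L1 loss. It assumes a mask with 1 for unknown pixels; the function names, sigma, and iteration count are hypothetical illustration values, not this repository's code or defaults.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def confidence_weights(mask, sigma=4.0, iters=7):
    """Spatially varying loss weights inside the hole.

    mask: HxW array, 1 for unknown pixels and 0 for known ones (assumed
    convention). sigma and iters are illustrative, not repository defaults.
    """
    mask = mask.astype(np.float64)
    weights = np.zeros_like(mask)
    for _ in range(iters):
        # Known pixels have full confidence (1); hole pixels carry the
        # confidence accumulated in the previous iteration.
        m_bar = (1.0 - mask) + weights
        # Smooth so confidence leaks from the boundary into the hole,
        # then keep the result only inside the hole.
        weights = gaussian_filter(m_bar, sigma=sigma) * mask
    return weights


def confidence_driven_l1(target, output, mask, **kwargs):
    """L1 loss on HxWxC images, weighted by the confidence map inside the hole."""
    w = confidence_weights(mask, **kwargs)[..., None]
    return float(np.mean(np.abs(target - output) * w))
```

Because each step averages values that never exceed 1, the weights stay in [0, 1] and decay toward the center of large holes, where predictions are least constrained.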
To pretrain the network, run:
python train.py --dataset [DATASET_NAME] --data_file [DATASET_TRAININGFILE] --gpu_ids [NUM] --pretrain_network 1 --batch_size 16
where [DATASET_TRAININGFILE] indicates a file storing the full paths of the training images.
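A file list for --data_file can be generated with a short script. This is a minimal sketch that assumes one absolute image path per line; the helper name and the set of extensions are hypothetical, so check the data loader in the repository for the exact format it expects.

```python
from pathlib import Path

# Extensions treated as images (an assumption; adjust to your dataset).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def write_file_list(image_dir, out_file):
    """Write one absolute image path per line to out_file and return the paths."""
    paths = sorted(
        str(p.resolve())
        for p in Path(image_dir).rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
    Path(out_file).write_text("".join(path + "\n" for path in paths))
    return paths
```

For example, write_file_list("/data/places2/train", "train.flist") would produce a list that can be passed as [DATASET_TRAININGFILE] (both paths here are placeholders).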
Then finetune the network with:
python train.py --dataset [DATASET_NAME] --data_file [DATASET_TRAININGFILE] --gpu_ids [NUM] --pretrain_network 0 --load_model_dir [PRETRAINED_MODEL_PATH] --batch_size 8
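The difference between the two stages comes down to which loss terms are active. The sketch below shows that switch in plain Python; the function name and the weights lambda_adv and lambda_mrf are hypothetical placeholders, not the repository's defaults.

```python
def training_loss(recon, adv, id_mrf, pretrain, lambda_adv=1e-3, lambda_mrf=5e-2):
    """Combine loss terms for the two training stages (sketch).

    pretrain=True mirrors --pretrain_network 1: only the confidence-driven
    reconstruction loss is used. pretrain=False mirrors --pretrain_network 0:
    the adversarial and ID-MRF terms are added on top of it.
    The lambda weights are hypothetical placeholders.
    """
    if pretrain:
        return recon
    return recon + lambda_adv * adv + lambda_mrf * id_mrf
```

Starting the adversarial phase from a converged reconstruction model gives the discriminator reasonable inputs from the first finetuning step, which is the usual motivation for this kind of schedule.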
We provide both random stroke and rectangle masks for the training and testing phases. The mask type is selected with the option --mask_type [rect(default)|stroke] when calling train.py or test.py.
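For reference, the two mask types can be sketched with NumPy alone. This is a minimal illustration, assuming 1 marks unknown pixels; the function names and all size parameters (hole size, stroke count, width, segment length) are hypothetical and do not reproduce the repository's own mask generator.

```python
import numpy as np


def rect_mask(h, w, mask_h, mask_w, rng):
    """Rectangle hole of size mask_h x mask_w at a random position (1 = unknown)."""
    top = rng.integers(0, h - mask_h + 1)
    left = rng.integers(0, w - mask_w + 1)
    mask = np.zeros((h, w), dtype=np.float32)
    mask[top:top + mask_h, left:left + mask_w] = 1.0
    return mask


def stroke_mask(h, w, rng, num_strokes=4, max_vertices=8, max_len=40, max_width=15):
    """Random free-form strokes: thick polylines with random angles and lengths."""
    mask = np.zeros((h, w), dtype=np.float32)
    yy, xx = np.mgrid[0:h, 0:w]
    for _ in range(num_strokes):
        y, x = rng.integers(0, h), rng.integers(0, w)
        radius = rng.integers(3, max_width // 2 + 1)
        for _ in range(rng.integers(1, max_vertices + 1)):
            angle = rng.uniform(0, 2 * np.pi)
            length = rng.integers(10, max_len + 1)
            ny = int(np.clip(y + length * np.sin(angle), 0, h - 1))
            nx = int(np.clip(x + length * np.cos(angle), 0, w - 1))
            # Stamp discs along the segment to draw a thick line.
            for t in np.linspace(0.0, 1.0, num=length + 1):
                cy, cx = y + t * (ny - y), x + t * (nx - x)
                mask[(yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2] = 1.0
            y, x = ny, nx
    return mask
```

Stamping discs keeps the sketch free of drawing libraries; OpenCV's cv2.line with a thickness argument is a faster alternative for the same effect.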
A simple interactive inpainting GUI

Other pretrained models
CelebA-HQ_512 trained with stroke masks.
For the PyTorch implementation
The testing and training procedures are similar to those in the TensorFlow version, except that some parameters have different names.
Testing
A pretrained model: CelebA-HQ_256.
Training
Compared with the TensorFlow version, the PyTorch version expects a relatively smaller batch size for training.
Other versions
Check out the Keras implementation of our paper by Tomasz Latkowski here.
Disclaimer
- For the provided pretrained models, performance degrades noticeably when they are evaluated with masks whose unknown areas are too large.
- As stated in the paper, model performance is unstable on large datasets with thousands of categories. Recent GANs using large-scale techniques may ease this problem.
- This repository does not contain the full implementation of ID-MRF described in our original paper: the step of excluding s is omitted for computational efficiency.
- In the PyTorch version, a different GAN loss (WGAN hinge loss with spectral normalization) is adopted.
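The hinge formulation mentioned for the PyTorch version can be sketched in NumPy: the discriminator is penalized for real scores below 1 and fake scores above -1, the generator pushes fake scores up, and spectral normalization divides each weight by its largest singular value, estimated by power iteration. The function names are hypothetical and this is not the repository's code; in PyTorch itself, torch.nn.utils.spectral_norm provides the normalization as a module wrapper.

```python
import numpy as np


def d_hinge_loss(real_scores, fake_scores):
    """Discriminator hinge loss: penalize real scores < 1 and fake scores > -1."""
    return float(np.mean(np.maximum(0.0, 1.0 - real_scores))
                 + np.mean(np.maximum(0.0, 1.0 + fake_scores)))


def g_hinge_loss(fake_scores):
    """Generator hinge loss: push the discriminator's fake scores up."""
    return float(-np.mean(fake_scores))


def spectral_normalize(weight, n_iters=1, u=None, eps=1e-12):
    """Divide a weight tensor by its largest singular value (power iteration).

    The tensor is flattened to (out_channels, -1). n_iters must be >= 1.
    Returns the normalized weight and the updated vector u, which can be
    reused across training steps so one iteration per step suffices.
    """
    w = weight.reshape(weight.shape[0], -1)
    if u is None:
        u = np.random.default_rng(0).normal(size=w.shape[0])
    for _ in range(n_iters):
        v = w.T @ u
        v /= np.linalg.norm(v) + eps
        u = w @ v
        u /= np.linalg.norm(u) + eps
    sigma = u @ w @ v
    return weight / sigma, u
```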
Acknowledgments
Our code is partially based on Generative Image Inpainting with Contextual Attention and pix2pixHD. The implementation of the ID-MRF loss is borrowed from contextual loss.
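Since the ID-MRF implementation is borrowed from contextual loss, it may help to see the kind of computation that involves. The NumPy sketch below follows the contextual-loss formulation on precomputed patch features (in practice these would come from a pretrained CNN): cosine distances are made relative to each generated patch's nearest target, turned into normalized affinities, and the best match for each target patch is averaged. It is a simplified illustration, not the paper's full ID-MRF and not the repository's code; the function name and the bandwidth h are hypothetical.

```python
import numpy as np


def contextual_style_loss(gen_feats, tgt_feats, h=0.5, eps=1e-5):
    """Contextual-loss-style patch matching between two sets of features.

    gen_feats: (N, C) patch features from the generated image.
    tgt_feats: (M, C) patch features from the target image.
    Returns a scalar near 0 when every target patch has a close,
    distinctive match among the generated patches.
    """
    # Center both sets by the target mean and L2-normalize each patch.
    mu = tgt_feats.mean(axis=0, keepdims=True)
    g = gen_feats - mu
    t = tgt_feats - mu
    g = g / (np.linalg.norm(g, axis=1, keepdims=True) + eps)
    t = t / (np.linalg.norm(t, axis=1, keepdims=True) + eps)
    # Cosine distance between generated patch i and target patch j: shape (N, M).
    dist = 1.0 - g @ t.T
    # Relative distance: compare each pair with generated patch i's nearest target.
    rel = dist / (dist.min(axis=1, keepdims=True) + eps)
    # Affinities, normalized over target patches for each generated patch.
    w = np.exp((1.0 - rel) / h)
    cx = w / w.sum(axis=1, keepdims=True)
    # Best generated match for each target patch, averaged, then -log.
    return float(-np.log(cx.max(axis=0).mean() + eps))
```

Using relative rather than absolute distances rewards matches that are clearly better than the alternatives, which is what discourages many generated patches from collapsing onto the same target patch.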
Contact
Please send email to [email protected].