A PyTorch implementation of the paper 'Adversarial PoseNet: A Structure-aware Convolutional Network for Human Pose Estimation' (https://arxiv.org/pdf/1705.00389v2.pdf)
Adversarial Pose Estimation
Abstract
This repository aims to replicate the results of this paper. The idea is to improve human pose estimation by using a GAN-based framework, where the (conditional) generator learns the distribution P(y|x), where x is the image and y is the heatmap for the person. Typical keypoint detectors simply employ a similarity-based loss (MSE or cross-entropy) between the predicted heatmaps and the ground-truth heatmaps. However, these losses tend to produce smooth outputs because they are averaged over the entire spatial domain. The idea here is to make the predictions "crisper and sharper" by employing discriminators that differentiate between ground-truth and predicted heatmaps in two different ways.
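For context, the ground-truth heatmap y for each joint is typically rendered as a small 2D Gaussian centered on the annotated keypoint. A minimal sketch of that rendering step (the function name and sigma value are illustrative, not taken from this repository):

```python
import numpy as np

def render_gaussian_heatmap(x, y, height, width, sigma=1.0):
    """Render a ground-truth heatmap with a Gaussian peak at (x, y)."""
    xs = np.arange(width)[None, :]   # shape (1, W)
    ys = np.arange(height)[:, None]  # shape (H, 1)
    # Peak value 1.0 at the keypoint, decaying with distance from it.
    return np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * sigma ** 2))

# Example: a 64x64 heatmap for a joint at pixel (20, 30).
heatmap = render_gaussian_heatmap(20, 30, height=64, width=64, sigma=1.5)
```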
Framework
An overview of the architecture is given here:
This framework consists of a generator network and two discriminator networks. Why two? One of the discriminators captures the similarity of (x, y) pairs, which is what a traditional conditional discriminator would do, and the second discriminator compares the "quality" of heatmaps generated by the network against the ground-truth heatmaps. The former is called the Pose Discriminator and takes the image and heatmaps as inputs; the latter is called the Confidence Discriminator and takes only the heatmaps as input. This makes the heatmaps sharper than what is traditionally achieved with a per-pixel loss.
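To make the data flow concrete, here is a minimal sketch of how the two discriminators could be wired up. The class names, layer choices, and tensor shapes are illustrative placeholders, not this repository's actual modules:

```python
import torch
import torch.nn as nn

class PoseDiscriminator(nn.Module):
    """Scores (image, heatmap) pairs; placeholder architecture."""
    def __init__(self, in_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 1, 4, stride=2, padding=1),
        )

    def forward(self, image, heatmaps):
        # Condition on the image by concatenating it with the heatmaps.
        return self.net(torch.cat([image, heatmaps], dim=1)).mean(dim=(1, 2, 3))

class ConfidenceDiscriminator(nn.Module):
    """Scores heatmaps alone; placeholder architecture."""
    def __init__(self, num_joints):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_joints, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 1, 4, stride=2, padding=1),
        )

    def forward(self, heatmaps):
        return self.net(heatmaps).mean(dim=(1, 2, 3))

# image: (B, 3, H, W); predicted / ground-truth heatmaps: (B, num_joints, H, W)
# p_real = pose_disc(image, gt_heatmaps);  p_fake = pose_disc(image, pred_heatmaps)
# c_real = conf_disc(gt_heatmaps);         c_fake = conf_disc(pred_heatmaps)
```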
The architectures used for the discriminators are the same as the ones described in the paper. The generator is a typical stacked hourglass architecture with intermediate supervision modules. The paper uses an MSE loss for learning the heatmaps. However, we noticed that training with MSE is very slow, because the maximum per-pixel difference is 1 and the ground-truth heatmaps are sparse. Hence, we use a weighted binary cross-entropy loss to balance the ratio of positives vs. negatives. This results in much faster training and convergence. The discriminator loss is a plain binary cross-entropy loss (since real and fake pairs are given in equal ratios).
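A minimal sketch of such a weighted BCE heatmap loss is below. The specific weighting scheme (scaling positive pixels by the negative-to-positive pixel ratio) is an assumption for illustration and not necessarily the exact weighting used in this repository:

```python
import torch
import torch.nn.functional as F

def weighted_bce_heatmap_loss(pred, target, eps=1e-6):
    """BCE over heatmap pixels, up-weighting the sparse positive pixels.

    pred, target: tensors of shape (B, num_joints, H, W) with values in [0, 1].
    """
    positive = (target > 0.5).float()
    # Ratio of negative to positive pixels, used as the positive-class weight.
    pos_weight = (1.0 - positive).sum() / (positive.sum() + eps)
    weights = torch.where(positive.bool(), pos_weight, torch.ones_like(target))
    return F.binary_cross_entropy(pred, target, weight=weights)
```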
Dependencies
The list of dependencies can be found in the `requirements.txt` file. Simply run `pip install -r requirements.txt` to install them.
Running the code
WARNING: GAN training can be unstable and may also depend on your PyTorch/CUDA versions. If the default configuration doesn't work, try tuning the other parameters.
Running the code for training is fairly easy. Follow these steps.
- Go to `config/default_config.py` to edit the hyperparameters as per your choice. For your convenience, the default parameters are already set.
- Download the extended LSP dataset here and save it in your favorite directory. Your dataset directory should look like this (if the root dataset dir is `lspet_dataset/`); a quick sanity check for the annotations is sketched right after this list:

```
lspet_dataset/
    images/
        im00001.jpg
        im00002.jpg
        ...
    joints.mat
    README.txt
```

- Add this path to the `--path` parameter in the `train.sh` script. This script contains all the other parameters required to train the model.
- Run the script.
- The pretrained file can be found in the Downloads section of the README.
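As a quick sanity check that the dataset path is set up correctly, you can load the annotation file and inspect its shape. This snippet is illustrative only; the exact array layout of `joints.mat` is documented in the dataset's `README.txt`:

```python
import scipy.io

# Hypothetical path; point this at your own dataset root (the --path value).
joints = scipy.io.loadmat('lspet_dataset/joints.mat')['joints']
print(joints.shape)  # annotated joint coordinates and visibility flags
```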
Results
We got a [email protected] value of 0.606893 on the validation dataset. We trained with the binary cross-entropy loss and a batch size of 1. This score is low; however, we trained for only about a day (since we had bugs in our earlier code). Here are some qualitative results:
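For reference, PCK counts a predicted keypoint as correct when it lies within a threshold fraction of a reference length (e.g. torso size for PCK, head size for PCKh) of the ground-truth location. A minimal sketch of the metric, not the evaluation code used by this repository:

```python
import numpy as np

def pck(pred, gt, ref_length, alpha=0.2):
    """Percentage of Correct Keypoints.

    pred, gt: arrays of shape (num_samples, num_joints, 2) in pixel coordinates.
    ref_length: array of shape (num_samples,), e.g. torso diagonal per sample.
    alpha: threshold expressed as a fraction of the reference length.
    """
    dists = np.linalg.norm(pred - gt, axis=-1)       # (num_samples, num_joints)
    correct = dists <= alpha * ref_length[:, None]   # broadcast per-sample threshold
    return correct.mean()
```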
References
If you liked this repository and would like to use it in your work, consider citing the original paper.
@article{DBLP:journals/corr/ChenSWLY17,
author = {Yu Chen and
Chunhua Shen and
Xiu{-}Shen Wei and
Lingqiao Liu and
Jian Yang},
title = {Adversarial PoseNet: {A} Structure-aware Convolutional Network for
Human Pose Estimation},
journal = {CoRR},
volume = {abs/1705.00389},
year = {2017},
url = {http://arxiv.org/abs/1705.00389},
archivePrefix = {arXiv},
eprint = {1705.00389},
timestamp = {Mon, 13 Aug 2018 16:47:51 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/ChenSWLY17},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
We also thank Naman's repository for providing the code for PCK and PCKh metrics.
Downloads
TODO: Put a drive link to pretrained model.