# Soft Attention Image Captioning

TensorFlow implementation of *Show, Attend and Tell* (ICML'15).

Major refactor since the last update; now compatible with TensorFlow >= r1.0.
## Prerequisites
- Python 2.7+
- NumPy
- TensorFlow r1.0+
- Scikit-image
- tqdm
## Data

- Training: the Microsoft COCO (Common Objects in Context) training and validation sets
## Preparation
- Clone this repo and create the `data/` and `log/` folders:

  ```bash
  git clone https://github.com/markdtw/soft-attention-image-captioning.git
  cd soft-attention-image-captioning
  mkdir data
  mkdir log
  ```

- Download and extract the pre-trained `Inception V4` and `VGG 19` models from tf.slim for feature extraction. Save the checkpoint files in `cnns/` as `inception_v4_imagenet.ckpt` and `vgg_19_imagenet.ckpt`.

- We need the following files in our `data/` folder:
  - `coco_raw.json`
  - `coco_processed.json`
  - `coco_dictionary.pkl`
  - `coco_final.json`
  - `train2014_vgg(inception).npy` and `val2014_vgg(inception).npy`

  These files can be generated through `utils.py`; please refer to it before executing. A sketch of the extraction idea follows this list.

- If you are not able to extract the features yourself, here is the features download link (it may take a long time to download):
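For reference, below is a minimal sketch (TF 1.x) of the VGG-19 extraction step using `tf.slim`: it pulls the `conv5_3` feature map, i.e. 196 annotation vectors of 512 dimensions each, which matches the VGG `ctx_dim` of (196, 512). The image file name and the mean-subtraction preprocessing are assumptions for illustration; `utils.py` is the actual pipeline.

```python
import numpy as np
import skimage.io
import skimage.transform
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.nets import vgg

# Mean pixel used by standard VGG preprocessing (assumption: utils.py
# may preprocess differently).
VGG_MEAN = np.array([123.68, 116.779, 103.939])

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(vgg.vgg_arg_scope()):
    _, end_points = vgg.vgg_19(images, num_classes=1000, is_training=False)
# conv5_3 is a 14x14x512 map -> reshape into 196 annotation vectors of 512 dims.
features = tf.reshape(end_points['vgg_19/conv5/conv5_3'], [-1, 196, 512])

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'cnns/vgg_19_imagenet.ckpt')
    img = skimage.transform.resize(
        skimage.io.imread('some_image.jpg'), (224, 224),
        preserve_range=True) - VGG_MEAN
    feats = sess.run(features, {images: img[np.newaxis].astype(np.float32)})
    # feats has shape (1, 196, 512); stacking these per split would give
    # the train2014_vgg.npy / val2014_vgg.npy files.
```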
## Train
Train from scratch with default settings:

```bash
python main.py --train
```

Resume training from a pre-trained model saved at epoch X:

```bash
python main.py --train --model_path=log/model.ckpt-X
```

Check out the tunable arguments:

```bash
python main.py
```
## Generate a caption
Using the default (latest) model:

```bash
python main.py --generate --img_path=/path/to/image.jpg
```

Using the model from epoch X:

```bash
python main.py --generate --img_path=/path/to/image.jpg --model_path=log/model.ckpt-X
```
## Others
- The extracted features are around 16 GB + 8 GB; make sure you have enough CPU memory when loading the data.
- GPU memory usage for `batch_size` 128 is around 8 GB.
- The RNN is implemented with `tf.while_loop`, and `tf.slim` (from its GitHub page) is used for feature extraction; see the sketch after this list.
- A GRU cell is also implemented; use it by setting `--use_gru=True` when training.
- Features can instead be extracted with Inception V4; if so, `model.ctx_dim` in `model.py` needs to be set to (64, 1536) (other modifications are also needed).
- Issues are welcome!
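To make the `tf.while_loop` / soft-attention note above concrete, here is a rough, self-contained sketch (not the repo's actual code) of one soft-attention step and a `tf.while_loop` decode loop, assuming VGG features with `ctx_dim` (196, 512) and a GRU cell; all names and dimensions are illustrative.

```python
import tensorflow as tf

# Illustrative dimensions: 196 annotation vectors of 512 dims (VGG conv5_3).
N_CTX, DIM_CTX, DIM_HIDDEN, DIM_ATTN = 196, 512, 1024, 512

def soft_attention(features, h_prev, w_f, w_h, w_a):
    """Weigh the 196 annotation vectors by their relevance to h_prev."""
    f_proj = tf.tensordot(features, w_f, axes=[[2], [0]])   # (B, 196, A)
    h_proj = tf.expand_dims(tf.matmul(h_prev, w_h), 1)      # (B, 1, A)
    e = tf.tensordot(tf.tanh(f_proj + h_proj), w_a, axes=[[2], [0]])
    alpha = tf.nn.softmax(tf.squeeze(e, [2]))               # (B, 196), sums to 1
    # "soft" context = expectation over annotation vectors under alpha
    return tf.reduce_sum(tf.expand_dims(alpha, 2) * features, 1)

def decode(features, word_embeddings, n_steps, batch_size):
    """features: (B, 196, 512); word_embeddings: (B, n_steps, dim_embed)."""
    w_f = tf.get_variable('w_f', [DIM_CTX, DIM_ATTN])
    w_h = tf.get_variable('w_h', [DIM_HIDDEN, DIM_ATTN])
    w_a = tf.get_variable('w_a', [DIM_ATTN, 1])
    cell = tf.contrib.rnn.GRUCell(DIM_HIDDEN)
    outputs = tf.TensorArray(tf.float32, size=n_steps)

    def body(t, h, outs):
        ctx = soft_attention(features, h, w_f, w_h, w_a)
        # feed [current word embedding; attended context] to the cell
        out, h = cell(tf.concat([word_embeddings[:, t, :], ctx], 1), h)
        return t + 1, h, outs.write(t, out)

    _, _, outputs = tf.while_loop(
        lambda t, *_: t < n_steps, body,
        [tf.constant(0), tf.zeros([batch_size, DIM_HIDDEN]), outputs])
    return outputs.stack()  # (n_steps, B, DIM_HIDDEN); project to vocab next
```

Using `tf.while_loop` instead of a statically unrolled RNN keeps the graph size independent of caption length, which is why libraries like `tf.nn.dynamic_rnn` take the same approach.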