illuminant_estimation
Deep Specialized Network for Illuminant Estimation
Illuminant Estimation
This project implements the illuminant estimation method presented in the paper "Deep Specialized Network for Illuminant Estimation". The implementation is based on Python and TensorFlow.
Prerequisites
- python=3.6
- tensorflow=1.14
- pyzmq=19.0.1 (optional)
Training from scratch
We use training on the Gehler-Shi dataset as an example. First, follow the instructions in data to preprocess the data. Then run the following commands to train three models, one per fold, since performance is evaluated by 3-fold cross validation.
CUDA_VISIBLE_DEVICES=0 python solver.py --gs-has-loc --gs-test-set 0 &
CUDA_VISIBLE_DEVICES=1 python solver.py --gs-has-loc --gs-test-set 1 &
CUDA_VISIBLE_DEVICES=2 python solver.py --gs-has-loc --gs-test-set 2 &
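The three commands above can also be launched from Python, which makes it easy to wait for every fold to finish. This is a minimal sketch, assuming one GPU per fold with the GPU index equal to the fold index, as in the shell commands; fold_command and launch_folds are hypothetical helpers, and only the solver.py flags shown in this README are used.

```python
import os
import subprocess


def fold_command(fold, test_only=False):
    """Build the solver.py command line for one cross-validation fold."""
    cmd = ["python", "solver.py", "--gs-has-loc", "--gs-test-set", str(fold)]
    if test_only:
        # Same flag as in the Test section: evaluate instead of training.
        cmd.append("--test-only")
    return cmd


def launch_folds(folds=(0, 1, 2), test_only=False):
    """Start one solver.py process per fold, each restricted to its own GPU.

    Assumption: GPU index == fold index, mirroring the shell commands.
    """
    procs = []
    for fold in folds:
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(fold))
        procs.append(subprocess.Popen(fold_command(fold, test_only), env=env))
    # Block until every fold finishes and return the exit codes.
    return [p.wait() for p in procs]
```

Calling launch_folds() starts the three training runs; launch_folds(test_only=True) starts the corresponding test runs.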
NOTE: ZeroMQ (pyzmq) is recommended for efficient training. Training each model takes roughly 12 hours on a single GeForce GTX TITAN X GPU.
With the default training parameters, the final model parameters are saved in models with the following layout:
models
|- gs568-0_bs128_lr0.02
|  |- hypnet_4000000.npz
|  |- selnet_4000000.npz
|
|- gs568-1_bs128_lr0.02
|  |- hypnet_4000000.npz
|  |- selnet_4000000.npz
|
|- gs568-2_bs128_lr0.02
   |- hypnet_4000000.npz
   |- selnet_4000000.npz
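A small helper can check that this layout is complete before testing. It is only a sketch: the naming pattern is inferred from the listing above (gs568 presumably refers to the 568 Gehler-Shi images, bs to the batch size, lr to the learning rate, and 4000000 to the training step at which the parameters were saved), and model_dir and missing_checkpoints are hypothetical names.

```python
from pathlib import Path


def model_dir(fold, root="models", batch_size=128, lr="0.02"):
    """Directory for one fold, following the naming seen in the listing (assumed pattern)."""
    # lr is kept as a string so it is not reformatted (e.g. 0.02 vs 2e-02).
    return Path(root) / f"gs568-{fold}_bs{batch_size}_lr{lr}"


def missing_checkpoints(root="models", folds=(0, 1, 2), step=4000000):
    """Return the expected hypnet/selnet .npz files that do not exist yet."""
    missing = []
    for fold in folds:
        d = model_dir(fold, root)
        for net in ("hypnet", "selnet"):
            path = d / f"{net}_{step}.npz"
            if not path.is_file():
                missing.append(path)
    return missing
```

An empty list from missing_checkpoints() means all six files are in place.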
Test
Then run the following commands to test each model on its held-out fold:
CUDA_VISIBLE_DEVICES=0 python solver.py --gs-has-loc --gs-test-set 0 --test-only &
CUDA_VISIBLE_DEVICES=1 python solver.py --gs-has-loc --gs-test-set 1 --test-only &
CUDA_VISIBLE_DEVICES=2 python solver.py --gs-has-loc --gs-test-set 2 --test-only &
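Because the three test folds together cover every image exactly once, cross-validation results are usually reported over the pooled per-image errors. This README does not state which metric the scripts report; the sketch below assumes the angular error between estimated and ground-truth RGB illuminants, the metric most commonly used in color constancy. It works on global per-image estimates (such as those produced by test_preds.py) that you have already loaded yourself, since the file format of preds is not described here; all function names are hypothetical.

```python
import math
import statistics


def angular_error_deg(est, gt):
    """Angle in degrees between an estimated and a ground-truth RGB illuminant."""
    dot = sum(a * b for a, b in zip(est, gt))
    norm = math.sqrt(sum(a * a for a in est)) * math.sqrt(sum(b * b for b in gt))
    # Clamp to [-1, 1] so floating-point round-off cannot break acos.
    cos = max(-1.0, min(1.0, dot / norm))
    return math.degrees(math.acos(cos))


def pooled_errors(folds):
    """Flatten per-fold lists of (estimate, ground truth) pairs into one error list."""
    return [angular_error_deg(est, gt) for fold in folds for est, gt in fold]


def summarize(errors):
    """Mean, median and maximum of per-image angular errors (degrees)."""
    return {
        "mean": statistics.mean(errors),
        "median": statistics.median(errors),
        "max": max(errors),
    }
```

Pass one list of (estimate, ground truth) pairs per fold to pooled_errors, then summarize the result.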
Pre-trained models
Pre-trained models can be downloaded from the following links. Unzip the files into models.
| Link | Description |
|---|---|
| OneDrive | Trained for 3-fold cross validation |
| OneDrive | Trained on all images |
Testing stores the estimated illuminants for the local patches of each image in preds. Finally, run the following command to aggregate them into a global prediction for each image:
python test_preds.py --weighted-median
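The flag name suggests that the patch-level estimates are combined with a weighted median. The sketch below is a hypothetical illustration of that idea, not the code of test_preds.py: it takes a per-channel weighted median over patch RGB estimates, with positive weights supplied by the caller (what the script actually uses as weights is not stated here), and then normalizes the result to unit sum.

```python
def weighted_median(values, weights):
    """Smallest value whose cumulative weight reaches half of the total weight."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    half = total / 2.0
    acc = 0.0
    for value, weight in sorted(zip(values, weights)):
        acc += weight
        if acc >= half:
            return value
    return value  # only reachable through round-off; fall back to the largest value


def global_illuminant(patch_rgbs, patch_weights):
    """Hypothetical aggregation: per-channel weighted median, normalized to sum to 1."""
    est = [weighted_median([p[c] for p in patch_rgbs], patch_weights) for c in range(3)]
    s = sum(est)
    return [x / s for x in est]
```

Taking the median per channel keeps a few badly estimated patches from dominating the global result, which a weighted mean would not.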