ENet-ScanNet
ENet for 2D semantic segmentation in ScanNet
PyTorch (v1.0.0) re-implementation of ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation, ported from the excellent PyTorch implementation ENet-ScanNet, which was in turn ported from the lua-torch implementation ENet-training created by the authors.
This implementation has been tailored to suit the ScanNet dataset.
What's new from ENet?
The primary change from ENet is that this repository supports a variant, `ENetDepth`, which is a 2.5D (i.e., RGB + depth) version of ENet. It takes a color image and its corresponding depth image as input and performs semantic segmentation. `ENetDepth` provides a substantial boost in performance (+0.17 mIoU over ENet), so use it whenever depth data is available.
On an NVIDIA GeForce GTX 1060, `ENetDepth` runs at about 40 Hz; it should run faster on a better GPU.
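As a rough sketch (the tensor shapes and channel ordering below are illustrative assumptions, not the exact interface of the models in this repository), a 2.5D input can be assembled like this:

```python
import torch

# Illustrative only: ENetDepth consumes color + depth. A common way to build
# a 2.5D input is to stack the depth map as a fourth channel; the shapes and
# channel ordering here are assumptions, not this repository's exact interface.
rgb = torch.rand(1, 3, 480, 640)    # normalized color image (B, C, H, W)
depth = torch.rand(1, 1, 480, 640)  # aligned depth map at the same resolution

rgbd = torch.cat([rgb, depth], dim=1)  # shape: (1, 4, 480, 640)
print(rgbd.shape)

# A 4-channel input implies the network's first convolution must accept
# 4 input channels instead of 3.
```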
Installation
- Python 3 and pip.
- Set up a virtual environment (recommended).
- Install the dependencies with pip: `pip install -r requirements.txt` (for example, as shown below).
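For example, using Python's built-in venv module (the environment name enet-env is arbitrary):
python3 -m venv enet-env
source enet-env/bin/activate
pip install -r requirements.txt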
Obtaining ScanNet
Instructions for obtaining ScanNet are available in the README of this repository. The entire dataset is huge (about 1.4 TB)! However, you can choose to download only specific scenes; training a well-performing ENet needs about 75-100 scenes from ScanNet.
Pretrained models
A pretrained model ships with the repository; it lives in the `save` directory.
IMPORTANT: Make sure to edit `save/ENetDepth-scannet20_summary.txt` and specify the paths to your ScanNet data, in order for the pretrained model to work.
Data preparation
Data preparation for ENet is identical to that for 3DMV.
For more information, see data preparation.
Usage
Run `main.py`, the main script used for training and/or testing the model. The script has many options; make sure to read through most of them before training or testing.
usage: main.py [-h] [--mode {train,test,inference,full}] [--resume]
[--generate-images] [--arch {rgb,rgbd}]
[--seg-classes {nyu40,scannet20}] [--batch-size BATCH_SIZE]
[--epochs EPOCHS] [--learning-rate LEARNING_RATE]
[--beta0 BETA0] [--beta1 BETA1] [--lr-decay LR_DECAY]
[--lr-decay-epochs LR_DECAY_EPOCHS]
[--weight-decay WEIGHT_DECAY] [--dataset {scannet}]
[--dataset-dir DATASET_DIR] [--trainFile TRAINFILE]
[--valFile VALFILE] [--testFile TESTFILE] [--height HEIGHT]
[--width WIDTH] [--weighing {enet,mfb,none}]
[--class-weights-file CLASS_WEIGHTS_FILE] [--with-unlabeled]
[--workers WORKERS] [--print-step PRINT_STEP] [--imshow-batch]
[--device DEVICE] [--name NAME] [--save-dir SAVE_DIR]
[--validate-every VALIDATE_EVERY]
For help on the optional arguments, run `python main.py -h`.
Examples: Training
python main.py -b 64 --epochs 200 --dataset-dir /path/to/scannet/scannetv2_images/ --trainFile cache/train.txt --valFile cache/val.txt --testFile cache/test.txt --print-step 25 --seg-classes scannet20 --class-weights-file cache/class_weights_scannet20.txt --name ENetDepth --lr-decay-epochs 60 -lr 1e-3 --beta0 0.7 --arch rgbd --validate-every 10
Training for 200 epochs takes about 5-6 hours on an NVIDIA GeForce GTX TITAN X GPU.
Examples: Resuming training
python main.py -b 64 --epochs 200 --dataset-dir /path/to/scannet/scannetv2_images/ --trainFile cache/train.txt --valFile cache/val.txt --testFile cache/test.txt --print-step 25 --seg-classes scannet20 --class-weights-file cache/class_weights_scannet20.txt --name ENetDepth --lr-decay-epochs 60 -lr 1e-3 --beta0 0.7 --arch rgbd --validate-every 10 --resume
Examples: Inference
Once you're all trained and set, you can use `inference.py` to generate the cool-looking qualitative results shown at the top of this README. A sample `inference.py` call looks like this:
python inference.py --mode inference -b 2 --epochs 1 --dataset-dir /path/to/scannet/scannetv2_images/ --trainFile cache/train.txt --valFile cache/val.txt --testFile cache/test.txt --arch rgbd --print-step 1 --seg-classes scannet20 --class-weights-file cache/class_weights_scannet20.txt --name ENetDepth --generate-images
This will create a directory named `ENetDepth_images` inside the `save` directory.
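If you prefer to run the trained network from your own script rather than through `inference.py`, a rough sketch might look like the following (the model class name, module path, checkpoint file, and checkpoint key below are assumptions; check `models/` and `main.py` for the actual interfaces):

```python
import torch

# Hypothetical sketch: loading a trained ENetDepth checkpoint and segmenting
# one RGB-D frame. Names below are assumptions, not this repository's API.
from models.enet import ENetDepth  # assumed module path

num_classes = 21  # scannet20 labels + unlabeled (assumption)
model = ENetDepth(num_classes)

checkpoint = torch.load("save/ENetDepth", map_location="cpu")  # illustrative path
model.load_state_dict(checkpoint["state_dict"])  # key name is an assumption
model.eval()

rgb = torch.rand(1, 3, 240, 320)    # replace with a real, normalized color image
depth = torch.rand(1, 1, 240, 320)  # replace with the aligned depth map
rgbd = torch.cat([rgb, depth], dim=1)

with torch.no_grad():
    logits = model(rgbd)           # (1, num_classes, H, W)
    labels = logits.argmax(dim=1)  # per-pixel class indices
print(labels.shape)
```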
Project structure
Folders
- `data`: Contains instructions on how to download the datasets and the code that handles data loading.
- `metric`: Evaluation-related metrics.
- `models`: ENet model definition.
- `save`: By default, `main.py` will save models in this folder. The pretrained models can also be found here.
Files
- `args.py`: Contains all command-line options.
- `main.py`: Main script file used for training and/or testing the model.
- `test.py`: Defines the `Test` class, which is responsible for testing the model.
- `train.py`: Defines the `Train` class, which is responsible for training the model.
- `transforms.py`: Defines image transformations to convert an RGB image encoding classes to a `torch.LongTensor` and vice versa (see the sketch below).
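For intuition, here is a minimal sketch of such a conversion (the color-to-class palette below is made up for illustration; the actual encoding and transform classes live in `transforms.py`):

```python
import torch

# Illustrative only: map an RGB label image to a LongTensor of class indices.
# The palette here is invented; the real color encoding is defined in transforms.py.
PALETTE = {
    (0, 0, 0): 0,        # unlabeled
    (174, 199, 232): 1,  # e.g., wall
    (152, 223, 138): 2,  # e.g., floor
}

def rgb_label_to_long_tensor(rgb_label: torch.Tensor) -> torch.Tensor:
    """Convert a (3, H, W) uint8 RGB label image to an (H, W) LongTensor."""
    _, h, w = rgb_label.shape
    target = torch.zeros(h, w, dtype=torch.long)
    for color, class_index in PALETTE.items():
        color_tensor = torch.tensor(color, dtype=rgb_label.dtype).view(3, 1, 1)
        mask = (rgb_label == color_tensor).all(dim=0)  # pixels matching this color
        target[mask] = class_index
    return target

label_image = torch.zeros(3, 4, 4, dtype=torch.uint8)  # all "unlabeled" pixels
print(rgb_label_to_long_tensor(label_image))
```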