# CapsNetsLASeg
Capsule Networks and Convolutional Neural Networks for the Automated Segmentation of Left Atrium in Cardiac MRI
## Introduction
Comparing 2D capsule networks to 2D convolutional neural networks for automated left atrium segmentation, using Keras and TensorFlow. The proposed U-CapsNet (a SegCaps with a U-Net feature extractor) is based on the SegCaps architecture from *Capsules for Object Segmentation*. The convolutional neural networks are plain U-Nets without residual connections, based on *nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation*, which won the challenge this dataset was taken from.
## Dataset
The networks were trained and tested on the left atrium segmentation dataset (Task 2) from the 2018 Medical Segmentation Decathlon. It's a small dataset with only 20 labeled volumes and 10 test volumes, collected and labeled by King's College London. The volumes are mono-modal MRIs with corresponding binary ground truths (left atrium and background). The voxel spacing is a constant 1.37 mm × 1.25 mm × 1.25 mm, with the slice-thickness axis transposed to the first dimension. The main experiments do not resample the volumes because the spacing is constant throughout the dataset and anisotropic spacing has little effect on 2D networks.
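For context, the per-axis spacing comes from each volume's NIfTI affine. A minimal NumPy sketch of how spacing can be derived from a 4×4 affine (the diagonal matrix below is a hypothetical example mirroring this dataset's spacing, not read from an actual file):

```python
import numpy as np

# Hypothetical NIfTI-style affine: 1.37 mm slice spacing on the first
# (transposed-to-front) axis, 1.25 mm in-plane. Illustrative only.
affine = np.diag([1.37, 1.25, 1.25, 1.0])

def voxel_spacing(affine):
    """Per-axis voxel spacing (mm) from a 4x4 affine: the column norms
    of the upper-left 3x3 rotation/scaling block."""
    return np.sqrt((affine[:3, :3] ** 2).sum(axis=0))

spacing = voxel_spacing(affine)
```

In practice, a library such as nibabel reads this affine from the file header for you; the column-norm trick is just what that spacing computation reduces to.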
## Installation/Setup

### Dependencies
These are automatically installed through the regular repository installation below:
- `numpy>=1.10.2`
- `keras`
- `tensorflow`
- `nibabel`
- `batchgenerators`
- `sklearn`
Needs to be installed separately: `keras_med_io` (installation below).
### Regular repository installation

```
git clone https://github.com/jchen42703/CapsNetsLASeg.git
cd CapsNetsLASeg
pip install .
```
### keras_med_io installation

If you don't already have it installed:

```
git clone https://github.com/jchen42703/keras_med_io.git
cd keras_med_io
pip install .
```
### Dataset Installation
Run the following command in a terminal to automatically download the dataset, `Task02_Heart`, into `dset_path` and create a new preprocessed dataset, `Preprocessed_Heart`, in `output_path`:

```
python ./CapsNetsLASeg/scripts/download_preprocess.py --dset_path=dset_path --output_path=output_path
```
Please check the actual script for more in-depth documentation.
## Quick Tutorial
An example of how the scripts can be run is in `examples/Running_CapsNetsLASeg_Scripts_[Demo].ipynb`. Also, the documentation for the scripts lives mainly in the scripts themselves, in `scripts/`.
### Training

Run the script below with the appropriate arguments to train your desired model:

```
python ./CapsNetsLASeg/scripts/training.py --weights_dir=weights_dir --dset_path=./Preprocessed_Heart --model_name=name --epochs=n_epochs
```
Required Arguments:
- `--weights_dir`: Path to the base directory where you want to save your weights (does not include the `.h5` filename).
- `--dset_path`: Path to the base directory containing the `imagesTr` and `labelsTr` directories.
- `--model_name`: Name of the model you want to train. Either `cnn`, `capsr3`, `ucapsr3`, or `cnn-simple`.
- `--epochs`: Number of epochs.
You can view the other optional arguments, such as `batch_size`, `n_pos`, `lr`, etc., in the original script.
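All of these models are 2D, so training operates on individual slices of the 3D volumes described in the Dataset section. A minimal NumPy sketch of turning a volume into a channels-last slice stack (the shapes and the z-score normalization here are illustrative assumptions, not the repo's exact preprocessing):

```python
import numpy as np

# Toy stand-in for a loaded MRI volume: (slices, height, width).
# The slice axis is first, matching the transpose mentioned in Dataset.
volume = np.random.rand(88, 320, 320).astype(np.float32)

# Stack of 2D training samples with a trailing channel axis,
# the channels_last layout Keras' 2D layers expect.
slices = volume[..., np.newaxis]                  # (88, 320, 320, 1)

# Per-volume z-score normalization, a common preprocessing choice.
slices = (slices - slices.mean()) / (slices.std() + 1e-8)
```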
### Inference
Once you're done training, you can predict and evaluate on your held-out test set. Note that in `weights_dir`, you'll see `model_name_fold1.json`. This is a dictionary representing the file splits for a single fold of cross-validation, and the script below uses it to predict and evaluate on the held-out test set (different from the actual challenge test set, `imagesTs`).
```
python ./CapsNetsLASeg/scripts/inference.py --weights_path=./weights.h5 --raw_dset_path=./Task02_Heart --model_name=name --fold_json_path="./capsnetslaseg_fold1.json" --batch_size=17 --save_dir="./pred"
```
Required Arguments:
- `--weights_path`: Path to the saved weights (an `.h5` file).
- `--raw_dset_path`: Path to the base directory (`Task02_Heart`) containing the unpreprocessed `imagesTr` and `labelsTr` directories.
- `--model_name`: Name of the model you trained. Either `cnn`, `capsr3`, `ucapsr3`, or `cnn-simple`.
- `--fold_json_path`: Path to the json with the filename splits.
Similar to the previous section, you can view the other optional arguments, such as `batch_size`, `save_dir`, `decoder`, etc., in the original script.
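Conceptually, the fold json is just a serialized train/test filename split. A minimal sketch of producing one (the filenames, the `"train"`/`"test"` keys, and the 80/20 ratio are assumptions for illustration; the repo's actual json layout may differ):

```python
import json
import numpy as np

# Hypothetical filenames; the real ones come from imagesTr.
filenames = [f"la_{i:03d}.nii.gz" for i in range(20)]

rng = np.random.RandomState(42)          # fixed seed -> reproducible split
shuffled = list(rng.permutation(filenames))

# One fold's train/test split; 16/4 mirrors an 80/20 ratio on 20 volumes.
split = {"train": shuffled[:16], "test": shuffled[16:]}
fold_json = json.dumps(split, indent=2)  # what would be written to disk
```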
## Results
| Neural Network | Parameters | Test Dice | Weights |
|---|---|---|---|
| U-Net | 27,671,926 | 0.89 | [Download](https://drive.google.com/open?id=1G_0sgIig5wcJ-nrIpZdCsOB1uFXwaX23) |
| U-Net (Baseline) | 4,434,385 | 0.866 | [Download](https://drive.google.com/open?id=1Xm-TV1apc_LK8wJrDZC5pBGXeirE5S57) |
| U-CapsNet | 4,542,400 | 0.876 | [Download](https://drive.google.com/open?id=1ji0U9bd0GoLdvXwK9ARTwpUum1NiNiQ-) |
| SegCaps | 1,416,112 | 0.81 | [Download](https://drive.google.com/open?id=1k8f474s4rNwggtp3SWRTQXfLY-f85zvh) |
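The Test Dice column reports the Sørensen–Dice overlap between the predicted and ground-truth binary masks. A minimal NumPy version of the standard formula (a generic sketch, not the repo's `metrics` implementation):

```python
import numpy as np

def dice(pred, gt, eps=1e-8):
    """Binary Dice coefficient: 2*|A ∩ B| / (|A| + |B|)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return 2.0 * inter / (pred.sum() + gt.sum() + eps)

a = np.array([[1, 1, 0, 0]])
b = np.array([[1, 0, 0, 0]])
score = dice(a, b)   # 2*1 / (2 + 1) ≈ 0.667
```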
## Expanding to Other Datasets

Note that this repository is specifically catered towards the binary segmentation of mono-modal MRIs. However, the `AdaptiveUNet` architecture and the loss functions in `metrics` can be extended to multi-class problems.
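The usual way a binary Dice loss extends to multiple classes is a soft Dice averaged over one-hot channels. A NumPy sketch of that idea (this is a generic formulation, not the code in `metrics`; a loss would be one minus this score):

```python
import numpy as np

def multiclass_soft_dice(probs, onehot, eps=1e-8):
    """Mean soft Dice over classes.
    probs, onehot: (N, H, W, C) softmax outputs and one-hot labels."""
    axes = (0, 1, 2)                       # reduce over batch + spatial dims
    inter = (probs * onehot).sum(axis=axes)
    denom = probs.sum(axis=axes) + onehot.sum(axis=axes)
    return ((2.0 * inter + eps) / (denom + eps)).mean()

# Perfect two-class prediction on a tiny 2x2 image -> score of 1.
labels = np.eye(2)[np.array([[[0, 1], [1, 0]]])]   # (1, 2, 2, 2) one-hot
perfect = multiclass_soft_dice(labels, labels)     # 1.0
```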
## About Keras
Keras is a minimalist, highly modular neural networks library, written in Python and capable of running on top of either TensorFlow or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.
Use Keras if you need a deep learning library that:
- allows for easy and fast prototyping (through total modularity, minimalism, and extensibility).
- supports both convolutional networks and recurrent networks, as well as combinations of the two.
- supports arbitrary connectivity schemes (including multi-input and multi-output training).
- runs seamlessly on CPU and GPU.
Read the documentation: Keras.io
Keras is compatible with: Python 2.7-3.5.