MTL-Segmentation
Meta Transfer Learning for Few Shot Semantic Segmentation using U-Net
- Meta Transfer Learning for Few Shot Semantic Segmentation using U-Net
- Requirements
- Characteristics
- Model and Technique
- Datasets
- Losses
- Code structure
- Running Experiments
- Hyperparameters and Options
- Training Plots
- Results
- Acknowledgement
Requirements
PyTorch and Torchvision need to be installed before running the scripts, together with PIL for data preprocessing and tqdm for showing the training progress.
To run this repository, install Python 3.5 and PyTorch 0.4.0 with Anaconda.
You may download Anaconda and read the installation instruction on their official website: https://www.anaconda.com/download/
Create a new environment and install PyTorch and torchvision on it:
conda create --name segfew python=3.5
conda activate segfew
conda install pytorch=0.4.0
conda install torchvision -c pytorch
Clone this repository:
git clone https://github.com/ahirsharan/MTL_Segmentation.git
Characteristics:
Model and Technique
- (U-Net) Convolutional Networks for Biomedical Image Segmentation (2015): [Paper]
- (Meta Transfer Learning) Meta-Transfer Learning for Few-Shot Learning: [Paper]
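In Meta-Transfer Learning, the pre-trained weights stay frozen and only lightweight scaling and shifting (SS) parameters are meta-learned on top of them. The sketch below illustrates that idea on a tiny linear layer in plain Python. It is not the code in models/conv2d_mtl.py: the function names are hypothetical, and using one scale per output unit is a simplifying assumption.

```python
def scale_and_shift(weights, biases, scales, shifts):
    """Return adapted parameters: W' = W * scale, b' = b + shift.

    weights: frozen weights, one row per output unit
    biases:  frozen biases, one per output unit
    scales:  learned multiplicative factors (initialised to 1 in this sketch)
    shifts:  learned additive offsets (initialised to 0 in this sketch)
    """
    scaled_w = [[w * s for w in row] for row, s in zip(weights, scales)]
    shifted_b = [b + t for b, t in zip(biases, shifts)]
    return scaled_w, shifted_b


def linear(x, weights, biases):
    """Plain linear layer: y_j = sum_i W[j][i] * x[i] + b[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, biases)]


# Frozen "pre-trained" parameters (toy values, for illustration only).
frozen_w = [[1.0, 2.0], [3.0, 4.0]]
frozen_b = [0.5, -0.5]
x = [1.0, 1.0]

# With scales = 1 and shifts = 0 the layer behaves exactly like the frozen one.
w0, b0 = scale_and_shift(frozen_w, frozen_b, [1.0, 1.0], [0.0, 0.0])
identity_out = linear(x, w0, b0)

# Changing only the SS parameters adapts the output without touching frozen_w.
w1, b1 = scale_and_shift(frozen_w, frozen_b, [2.0, 0.5], [0.1, 0.0])
adapted_out = linear(x, w1, b1)
```

Because only `scales` and `shifts` change between tasks, the number of meta-learned parameters is much smaller than the number of frozen weights.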
Datasets
- COCO Stuff: COCO Stuff comes in two partitions. CocoStuff10k contains only 10k images for training and evaluation; it is outdated but still useful for small-scale training and testing, and can be downloaded here. The official dataset, with all 164k examples, can be downloaded from the official website.
- Few-Shot: The Few-Shot dataset (FSS1000) contains 1000 object-class folders, each with 10 images and their ground-truth segmentation masks. This dataset can be used for few-shot learning and can be downloaded here.
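Few-shot datasets are usually split by class, so that the classes seen at meta-test time never appear during training. The sketch below shows one way to do this for a list of class-folder names. It is not the logic of FewShotPreprocessing.py: the function name, the split fractions, and the seed are all assumptions chosen for illustration.

```python
import random


def split_classes(class_names, val_fraction=0.1, test_fraction=0.2, seed=0):
    """Split class names into disjoint train/val/test sets.

    The fractions and the seed are hypothetical defaults, not values
    taken from the repository.
    """
    names = sorted(class_names)          # sort first so the split is reproducible
    random.Random(seed).shuffle(names)   # local RNG, global state is untouched
    n_test = int(len(names) * test_fraction)
    n_val = int(len(names) * val_fraction)
    return {
        "test": names[:n_test],
        "val": names[n_test:n_test + n_val],
        "train": names[n_test + n_val:],
    }


# Example with 1000 placeholder class names, mirroring the size of FSS1000.
splits = split_classes(["class_%04d" % i for i in range(1000)])
```

Splitting at the class level, rather than at the image level, keeps every image of a held-out class out of the training set.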
Losses
In addition to the Cross-Entropy loss:
- Dice loss, which measures the overlap between two samples and can be more reflective of the training objective (maximizing the mIoU), but is highly non-convex and can be hard to optimize.
- CE Dice loss, the sum of the Dice loss and CE. CE gives smooth optimization, while the Dice loss is a good indicator of the quality of the segmentation results.
- Focal loss, an alternative version of CE used to mitigate class imbalance by scaling down the contribution of confident predictions.
- Lovasz Softmax, which lends itself as a good alternative to the Dice loss: it directly optimizes the mean intersection-over-union using the convex Lovász extension of submodular losses (for more details, check the paper: The Lovász-Softmax loss).
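The losses listed above can be sketched for the binary case in plain Python, operating on flat lists of foreground probabilities and 0/1 targets. This is an illustrative sketch, not the contents of utils/losses.py or utils/lovasz_losses.py: the function names, the `EPS` and `smooth` constants, and the focal `gamma=2.0` default are assumptions, and the Lovász function is a simplified binary variant rather than the multi-class Lovász-Softmax.

```python
import math

EPS = 1e-7  # assumed clipping constant to keep log() finite


def cross_entropy(probs, targets):
    """Mean binary cross-entropy over all pixels."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, EPS), 1.0 - EPS)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(probs)


def dice_loss(probs, targets, smooth=1.0):
    """1 - soft Dice coefficient; `smooth` avoids division by zero."""
    intersection = sum(p * y for p, y in zip(probs, targets))
    return 1.0 - (2.0 * intersection + smooth) / (sum(probs) + sum(targets) + smooth)


def ce_dice_loss(probs, targets):
    """Sum of the CE and Dice losses."""
    return cross_entropy(probs, targets) + dice_loss(probs, targets)


def focal_loss(probs, targets, gamma=2.0):
    """CE weighted by (1 - p_t)^gamma, so confident pixels contribute less."""
    total = 0.0
    for p, y in zip(probs, targets):
        p_t = p if y == 1 else 1.0 - p  # probability assigned to the true class
        p_t = min(max(p_t, EPS), 1.0 - EPS)
        total -= ((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(probs)


def lovasz_binary(probs, targets):
    """Binary Lovász extension of the Jaccard loss.

    Pixels are sorted by decreasing error; each error is weighted by the
    increase in Jaccard loss caused by adding that pixel to the error set.
    """
    pairs = sorted(((abs(y - p), y) for p, y in zip(probs, targets)),
                   key=lambda pair: pair[0], reverse=True)
    n_fg = sum(targets)
    loss, seen_fg, seen_bg, prev_jaccard = 0.0, 0, 0, 0.0
    for err, y in pairs:
        seen_fg += y
        seen_bg += 1 - y
        jaccard = 1.0 - (n_fg - seen_fg) / (n_fg + seen_bg)
        loss += err * (jaccard - prev_jaccard)
        prev_jaccard = jaccard
    return loss


probs = [0.9, 0.8, 0.1, 0.2]   # toy predictions
targets = [1, 1, 0, 0]         # toy ground truth
losses = {
    "ce": cross_entropy(probs, targets),
    "dice": dice_loss(probs, targets),
    "ce_dice": ce_dice_loss(probs, targets),
    "focal": focal_loss(probs, targets),
    "lovasz": lovasz_binary(probs, targets),
}
```

On these toy inputs the focal loss is smaller than the CE loss, because every pixel is predicted fairly confidently and is therefore down-weighted.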
Code Structure
The code structure is based on MTL-template and Pytorch-Segmentation.
.
|
├── FewShotPreprocessing.py # utility to organise the Few-shot data into train,test and val set
|
|
├── dataloader
| ├── dataset_loader.py # data loader for pre datasets
| ├── mdataset_loader.py # data loader for meta task dataset
| └── samplers.py # samplers for meta task dataset(Few-Shot)
|
|
├── models
| ├── mtl.py # meta-transfer class
| ├── unet_mtl.py # unet class
| └── conv2d_mtl.py # meta-transfer convolution class
|
├── trainer
| ├── pre.py # pre-train trainer class
| └── meta.py # meta-train trainer class
|
|
├── utils
| ├── gpu_tools.py # GPU tool functions
| ├── metrics.py # Metrics functions
| ├── losses.py # Loss functions
| ├── lovasz_losses.py # Lovasz Loss function
| └── misc.py # miscellaneous tool functions
|
├── main.py # the python file with main function and parameter settings
├── run_pre.py # the script to run pre-train phase
└── run_meta.py # the script to run meta-train and meta-test phases
Running Experiments
Run pretrain phase:
python run_pre.py
Run meta-train and meta-test phase:
python run_meta.py
Hyperparameters and Options
Hyperparameters and options in main.py.
| Option | Description |
|---|---|
| model_type | The network architecture |
| dataset | Meta dataset |
| phase | pre-train, meta-train or meta-eval |
| seed | Manual seed for PyTorch, "0" means using random seed |
| gpu | GPU id |
| dataset_dir | Directory for the images |
| max_epoch | Epoch number for meta-train phase |
| num_batch | The number of different tasks used for meta-train |
| shot | Shot number, how many samples for one class in a task |
| teshot | Test-shot number, how many samples for one class in a meta-test task |
| way | Way number, how many classes in a task |
| train_query | The number of training samples for each class in a task |
| val_query | The number of test samples for each class in a task |
| meta_lr1 | Learning rate for SS weights |
| meta_lr2 | Learning rate for base learner weights (meta task) |
| base_lr | Learning rate for the inner loop |
| update_step | The number of updates for the inner loop |
| step_size | The number of epochs to reduce the meta learning rates |
| gamma | Gamma for the meta-train learning rate decay |
| init_weights | The pre-trained weights for meta-train phase |
| pre_init_weights | The pre-trained weights for pre-train phase |
| eval_weights | The meta-trained weights for meta-eval phase |
| meta_label | Additional label for meta-train |
| pre_max_epoch | Epoch number for pre-train phase |
| pre_batch_size | Batch size for pre-train phase |
| pre_lr | Learning rate for pre-train phase |
| pre_gamma | Gamma for the pre-train learning rate decay |
| pre_step_size | The number of epochs to reduce the pre-train learning rate |
| pre_custom_weight_decay | Weight decay for the optimizer during pre-train |
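The way, shot and train_query options together define what one meta-learning task looks like. The sketch below builds such a task from a mapping of class names to image identifiers. It is not the sampler in dataloader/samplers.py: `sample_episode` and its arguments are hypothetical names, and the toy data only mimics the FSS1000 layout of 10 images per class.

```python
import random


def sample_episode(images_by_class, way, shot, query, rng):
    """Sample one task: `way` classes, `shot` support and `query` query
    images per class. Returns two lists of (image_id, episode_label)."""
    classes = rng.sample(sorted(images_by_class), way)
    support, query_set = [], []
    for label, cls in enumerate(classes):
        # Draw without replacement so support and query never overlap.
        picks = rng.sample(images_by_class[cls], shot + query)
        support += [(img, label) for img in picks[:shot]]
        query_set += [(img, label) for img in picks[shot:]]
    return support, query_set


# Toy dataset: 5 classes with 10 image ids each (placeholder names).
toy_data = {
    "class_%d" % c: ["class_%d/img_%d" % (c, i) for i in range(10)]
    for c in range(5)
}
support, query_set = sample_episode(toy_data, way=2, shot=3, query=4,
                                    rng=random.Random(0))
```

With only 10 images per class, `shot + query` must not exceed 10, otherwise `rng.sample` raises a `ValueError`.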
Training Plots
Pre-Train Phase
[Plots: Mean IoU | CE Loss]
Meta-Train Phase
[Plots: Mean IoU | CE Loss]
Meta-Val Phase
[Plots: Mean IoU | CE Loss]
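Mean IoU, the metric tracked in each phase, averages the intersection-over-union of the predicted and ground-truth masks across classes. The sketch below computes it on flat lists of per-pixel class labels. It is a minimal illustration and may differ from utils/metrics.py; in particular, skipping classes that are absent from both prediction and target is an assumption.

```python
def mean_iou(pred, target, num_classes):
    """Mean intersection-over-union over classes present in pred or target."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        if union > 0:  # skip classes that appear in neither mask
            ious.append(inter / union)
    return sum(ious) / len(ious) if ious else 0.0


# Toy example: 4 pixels, 2 classes (0 = background, 1 = object).
pred = [0, 0, 1, 1]
target = [0, 1, 1, 1]
score = mean_iou(pred, target, num_classes=2)
```

Here class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12.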
Results
- The weights for both the Pre-Train and Meta tasks, saved at the maximum IoU, can be found here.
- Some of the best results for 3-shot learning :smile: :
[Result figures: Image | Ground Truth | Prediction]